In [1]:
# !pip install -q torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
# !pip install -q imagen-pytorch==1.17.0
# !pip install -q yagmail[all]
# !pip uninstall -q lxml
# !pip install -q lxml

import torch
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
from nslt_dataset import NSLT
import torchvision.transforms as T
import ipywidgets as widgets
from emailme import send_email

[0m

Downloading:   0%|          | 0.00/605 [00:00<?, ?B/s]

    https://github.com/beartype/beartype#pep-585-deprecations
  warn(


In [2]:
# send_email('./samples/', ['sample-22_book.gif'], 'yeaaah')

In [3]:
# Two identically configured 3D U-Nets. Their roles are fixed by their position in
# `unets` below: unet1 comes first (base stage), unet2 second (super-resolution stage).
unet1, unet2 = (
    Unet3D(
        dim = 128,
        dim_mults = (1, 2, 4, 8),
        lowres_cond = True,
        image_embed_dim = 32,
    ).cuda()
    for _ in range(2)
)

# Elucidated Imagen cascade wrapping the two U-Nets above.
imagen = ElucidatedImagen(
    # --- cascade layout: one entry per unet ---
    unets = (unet1, unet2),
    image_sizes = (16, 64),
    random_crop_sizes = (None, 8),

    # --- sampling / conditioning ---
    num_sample_steps = 10,
    cond_drop_prob = 0.1,

    # --- noise schedule ---
    sigma_min = 0.002,          # minimum noise level
    sigma_max = (80, 160),      # maximum noise level; doubled for the upsampler
    sigma_data = 0.5,           # standard deviation of the data distribution
    rho = 7,                    # controls the sampling schedule

    # --- log-normal distribution the training noise is drawn from ---
    P_mean = -1.2,              # mean
    P_std = 1.2,                # standard deviation

    # --- stochastic sampling; dataset dependent, see Table 5 in the paper ---
    # (presumably the "Elucidated" diffusion paper this class is named after -- confirm)
    S_churn = 30,
    S_tmin = 0.01,
    S_tmax = 1,
    S_noise = 1.007,

    # text_encoder_name = 't5-large',
).cuda()

In [4]:
# Input modality and dataset root shared by both splits.
mode = 'rgb'
root = {'word': './WLASL2000/'}

# Both splits are read from the same split file; the 'test' split is what the
# trainer later uses as its validation set. `mode` is passed through so the
# modality is configured in exactly one place.
dataset = NSLT('./preprocess/nslt_100.json', 'train', root=root, mode=mode, transforms=None)
val_dataset = NSLT('./preprocess/nslt_100.json', 'test', root=root, mode=mode, transforms=None)

Downloading:   0%|          | 0.00/945M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.74k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.81k [00:00<?, ?B/s]

Skipped videos:  0
1780
Skipped videos:  0
258


In [5]:
# for x,y in enumerate(dataset, 0):
#     if x<3:
#         print(y[0].size())
#         print(y[1])
#     else:
#         break

In [6]:
# Batch size shared by the training and validation loaders.
batch_size = 4

# Wrap the cascade in a trainer, then register both dataset splits with it.
trainer = ImagenTrainer(imagen, lr = 1e-4).cuda()
trainer.add_train_dataset(dataset, batch_size = batch_size)
trainer.add_valid_dataset(val_dataset, batch_size = batch_size)

In [7]:
# Text prompts passed to trainer.sample() whenever sample videos are generated.
texts = ['book', 'deaf', 'fine', 'yes', 'cool']

# Snippets kept for manual checkpointing / single-step debugging:
# trainer.save(f'./checkpoint{1}.pt')
# trainer.load(f'checkpoint{1}.pt')
# trainer.load(f'checkpoint{2}.pt')
# loss = trainer.train_step(unet_number = unet_training, max_batch_size = 4)
print('done')

done


In [None]:
def _mean(values):
    """Return the arithmetic mean of a non-empty sequence of numbers."""
    return sum(values) / len(values)


def _write_sample_gifs(videos, prompts, sample_tag):
    """Save one animated GIF per sampled video to ./samples/sample-{sample_tag}_{prompt}.gif.

    Assumes `videos` is laid out as (batch, channels, frames, height, width), so that
    swapping dims 1 and 2 puts the frame axis first within each clip -- TODO confirm
    against trainer.sample().
    """
    to_pil = T.ToPILImage()
    clips = [
        [to_pil(frame) for frame in clip.unbind(dim = 0)]
        for clip in videos.swapdims(2, 1)
    ]
    for prompt, frames in zip(prompts, clips):
        frames[0].save(f'./samples/sample-{sample_tag}_{prompt}.gif', format='GIF', save_all = True,
                       append_images = frames[1:],
                       optimize = False, duration = 200, loop=0,
                       disposal=3,
                      )


def _show_gif(path):
    """Display the GIF stored at `path` inline as a 512x512 widget."""
    # Read through a context manager so the file handle is closed right away.
    with open(path, 'rb') as gif_file:
        gif_bytes = gif_file.read()
    # `display` is the builtin that IPython injects into notebook kernels.
    display(widgets.Image(value = gif_bytes,
                          format='gif',
                          width=512,
                          height=512,
                          ))


def go():
    """Resume from checkpoint2.pt and train unet 2 for 100000 steps with periodic reporting.

    Schedule (step 0 only trains and records its loss; nothing is reported there):
      * every  250 steps: print the mean of the most recent (up to) 100 losses
      * every  500 steps: run a validation step, print the mean of the last (up to) 500 losses
      * every 1000 steps: on the main process, record the mean loss since the previous
        reset, sample a 20-frame video per prompt in `texts`, save each as a GIF under
        ./samples/ and display the 'book' and 'yes' samples
      * every 2000 steps: save the checkpoint
      * every 4000 steps: on the main process, email the latest GIFs plus the loss overview

    A final checkpoint is written when the loop finishes. Relies on the notebook
    globals `trainer`, `texts`, `T`, `widgets` and `send_email`.
    """
    unet_training = 2
    trainer.load(f'checkpoint{2}.pt')
    max_batch_size = 4
    running_totals = []   # losses since the last 1000-step reset
    overview = []         # one mean loss per completed 1000-step window

    for i in range(100000):
        loss = trainer.train_step(unet_number = unet_training, max_batch_size = max_batch_size)
        running_totals.append(loss)

        # Every periodic action below is skipped on the very first step. This now
        # includes the email, which used to fire at step 0 with an empty summary and
        # attachments (sample-0_*.gif) that this loop never writes.
        if i == 0:
            continue

        if i % 250 == 0:
            # Divide by the real window length rather than a hard-coded 100.
            print(f'avg_loss: {_mean(running_totals[-100:])}')

        if i % 500 == 0:
            valid_loss = trainer.valid_step(unet_number = unet_training, max_batch_size = max_batch_size)
            print(f'valid loss: {valid_loss}, total avg loss: {_mean(running_totals[-500:])}')

        if i % 1000 == 0 and trainer.is_main: # is_main makes sure this can run in distributed
            overview_total = _mean(running_totals)
            print(f'on step {i}, avg loss of last 1000 steps: {overview_total}')
            running_totals.clear()
            overview.append(overview_total)

            sample_tag = i // 1000
            videos = trainer.sample(texts = texts, video_frames = 20, stop_at_unet_number=unet_training)
            _write_sample_gifs(videos, texts, sample_tag)
            for prompt in ('book', 'yes'):
                _show_gif(f'./samples/sample-{sample_tag}_{prompt}.gif')

        if i % 2000 == 0:
            trainer.save(f'./checkpoint{unet_training}.pt')

        # The attachments and `overview` are only produced on the main process, so
        # only that process sends the mail.
        if i % 4000 == 0 and trainer.is_main:
            updates = "updates:\n" + "\n".join([f'{x:.4f}' for x in overview])
            send_email('./samples/', [f'sample-{i//1000}_{x}.gif' for x in texts], updates)

    trainer.save(f'./checkpoint{unet_training}.pt')

# Start the training loop defined above. It runs for 100000 steps; if the kernel is
# interrupted, the last checkpoint on disk is the one from the most recent 2000-step save.
go()

checkpoint loaded from checkpoint2.pt
avg_loss: 0.1657245621085167
avg_loss: 0.20621512413024903
valid loss: 0.11265818029642105, total avg loss: 0.18814878346025943
avg_loss: 0.1645312238484621
avg_loss: 0.16919931482523679
valid loss: 0.2581879496574402, total avg loss: 0.17348799180984498
on step 1000, avg loss of last 1000 steps: 0.18076613912364461


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xce\xe6\xe2\xb2\xe2\xd8\xad\xe2\xd5\xaa\xe4\xd2\xa9\xdc\xce\xa4\xe7…

Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\x9f\xed\xe4\x9e\xe6\xdd\x94\xe9\xd6\x9b\xe1\xd9\x95\xe2\xd6\x93\xe1…

avg_loss: 0.1678640579432249
avg_loss: 0.18788486570119858
valid loss: 0.1605975478887558, total avg loss: 0.18092206566780805
avg_loss: 0.17355257600545884
avg_loss: 0.16076122596859932
valid loss: 0.08762858808040619, total avg loss: 0.17338728019595145
on step 2000, avg loss of last 1000 steps: 0.17715467293187975


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xa5\xe5\xc9\xa1\xe1\xcd\xa0\xe0\xca\xa4\xdd\xc6\x9c\xdd\xcc\xad\xd9…

Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xbf\xe4\xd1\xac\xe1\xd7\xb4\xda\xcd\xb6\xd1\xcc\xa3\xd0\xc9\x90\xdc…

checkpoint saved to ./checkpoint2.pt
avg_loss: 0.1710332153737545
avg_loss: 0.1903467185050249
valid loss: 0.17666348814964294, total avg loss: 0.1765720390677452
avg_loss: 0.2141426555067301
avg_loss: 0.17541252456605436
valid loss: 0.17546993494033813, total avg loss: 0.1799412784576416
on step 3000, avg loss of last 1000 steps: 0.1782566587626934


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xb8\xec\xeb\x9c\xdf\xda\x90\xdf\xd9\x8f\xda\xd0\xa1\xd3\xd2\x94\xd3…

Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xba\xf3\xe4\xac\xee\xdf\xa4\xf1\xe5\xa0\xec\xdc\xa2\xea\xe0\xa7\xe4…

avg_loss: 0.16432345174252988
avg_loss: 0.16644683688879014
valid loss: 0.21089644730091095, total avg loss: 0.17458262219280005
avg_loss: 0.17761773619800805
avg_loss: 0.18426691085100175
valid loss: 0.06389383971691132, total avg loss: 0.16788379007577897
on step 4000, avg loss of last 1000 steps: 0.1712332061342895


0it [00:00, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

sampling time step:   0%|          | 0/10 [00:00<?, ?it/s]

Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xa3\xea\xe2\x97\xdd\xd9\xa7\xda\xda\x8e\xd9\xd1\x9d\xd6\xd5\x8e\xd6…

Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xc8\xe6\xe0\xb7\xe4\xd8\xc3\xda\xdc\xb6\xdb\xdd\xaa\xde\xd5\xa7\xd9…

checkpoint saved to ./checkpoint2.pt
avg_loss: 0.18532417714595795
avg_loss: 0.17001347687095403
valid loss: 0.13566994667053223, total avg loss: 0.17616512279957533


In [None]:
# pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116

In [None]:
# No stop_at_unet_number is passed here (unlike the periodic sampling inside go()),
# and 40 frames are requested instead of the 20 used there. The clip length seen
# during training is decided by the NSLT dataset and is not visible in this notebook.
videos = trainer.sample(texts = texts, video_frames = 40) # one video per prompt in `texts`

videos.shape # presumably (5, 3, 40, 64, 64) = (len(texts), channels, frames, H, W) -- TODO confirm

In [None]:
# Shape of the first entry along the leading (presumably batch / per-prompt) dimension.
videos[0].shape

In [None]:
# Turn every sampled video into a list of PIL frames. Assumes `videos` is laid out as
# (batch, channels, frames, height, width), so swapping dims 1 and 2 puts the frame
# axis first within each clip -- TODO confirm against trainer.sample().
to_pil = T.ToPILImage()
pil_images = [[to_pil(frame) for frame in clip.unbind(dim = 0)] for clip in videos.swapdims(2, 1)]

# Write one animated GIF per prompt into the working directory and show it inline.
for x, img in enumerate(pil_images):
    gif_path = f'{texts[x]}.gif'
    img[0].save(gif_path, format='GIF', save_all = True,
                append_images = img[1:],
                optimize = False, duration = 100, loop=0,
                disposal=3,
               )
    print(f'{texts[x]} signed:')
    # Read the file back through a context manager so the handle is closed
    # immediately instead of being left open for every prompt.
    with open(gif_path, 'rb') as gif_file:
        gif_bytes = gif_file.read()
    display(widgets.Image(value = gif_bytes,
             format='gif',
             width=512,
             height=512,
             ))

In [None]:
# trainer = ImagenTrainer(imagen).cuda()
# trainer.add_train_dataset(dataset, batch_size=4)
# trainer.add_valid_dataset(val_dataset, batch_size=4)
# unet_training = 2
# trainer.load(f'./checkpoint{1}.pt')
# for i in range(1000):
#     loss = trainer.train_step(unet_number = unet_training, max_batch_size = 4)
#     print(f'loss: {loss}')

#     if not (i % 50):
#         valid_loss = trainer.valid_step(unet_number = unet_training, max_batch_size = 4)
#         print(f'valid loss: {valid_loss}')

#     if not (i % 100) and trainer.is_main: # is_main makes sure this can run in distributed
#         videos = trainer.sample(texts = texts, video_frames = 20, stop_at_unet_number=unet_training)
#         pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim = 0))), videos.swapdims(2,1)))

#         for x in range(5):
#             display(pil_images[0][x*4])
#             pil_images[0][x*4].save(f'./samples/sample-{i // 100}_frame-{x*4}.png')

# trainer.save(f'./checkpoint{unet_training}.pt')