In [1]:
from autoencodeur import *
import torch
import PIL
import numpy as np
import tqdm
import wandb
import random
import os
from refiners.fluxion import utils

os.environ['WANDB_NOTEBOOK_NAME'] = '/home/daniel/work/vton/notebooks/autoencoder_cats_training.ipynb'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fix random seed
seed = 42
random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f8712709b50>

In [2]:
# Load dataset
path_dataset = "data_cats/"
image_dataset = ImageDataset(path_dataset).data
random.shuffle(image_dataset)

# Pass images to tensor
dataset = []
for i in tqdm.tqdm(range(len(image_dataset))):
    dataset.append(utils.image_to_tensor(image_dataset[i]).to(device))


# Split dataset
train_dataset = dataset[:int(len(dataset) * 0.7)]
val_dataset = dataset[int(len(dataset) * 0.7):int(len(dataset) * 0.9)]
test_dataset = dataset[int(len(dataset) * 0.9):]



100%|██████████| 11733/11733 [00:11<00:00, 1006.11it/s]


In [12]:
lr = 1e-4
num_epochs = 20
dropout = 0.05

# Initialize wandb
wandb.init(project='cats', entity = "finegrain-cs", name = '32latentspace_test', config={
    'lr': lr,
    'num_epochs': num_epochs,
    'seed': seed,
    'dropout': dropout,
})

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,▇█▅▅▃▃▄▄▃▄▅▄▃▃▂▃▃▂▂▃▃▂▃▄▂▃▂▂▃▃▁▂▂▂▃▂▁▃▁▃

0,1
epoch,0.0
loss,46.0822


In [13]:
autoencoder = AutoEncoder()
load_dropout(autoencoder, dropout=0.1)
autoencoder.to(device)
optimizer = torch.optim.Adam(autoencoder.parameters() , lr=lr)

In [14]:
# Training
for epoch in range(num_epochs):
    autoencoder.train()
    for image in tqdm.tqdm(train_dataset):
        y = autoencoder(image)
        loss = (y-image).norm()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        wandb.log({'epoch': epoch,  'loss': loss.item()})
    print(f"epoch : {epoch} ,loss: {loss.item()}")


    # Testing at the end of training
    loss_val = 0
    reconstructed_images = []

    with torch.no_grad():
        autoencoder.eval()
        for k, images_val in enumerate(val_dataset):
            y_val = autoencoder(images_val)
            l = (y_val - images_val).norm()
            loss_val += l

            # # Append the reconstructed images for visualization
            if k < 10:
                image_shape = images_val.shape
                concat = Image.new('RGB', (image_shape[-1]*2, image_shape[-2]))
                concat.paste(utils.tensor_to_image(images_val.data), (0, 0))
                concat.paste(utils.tensor_to_image(y_val.data), (image_shape[-1], 0))
                reconstructed_images.append(concat)

    images = [PIL.Image.fromarray(np.array(image)) for image in reconstructed_images]

    wandb.log({f"reconstructed_images": [wandb.Image(image) for image in images]})
    wandb.log({'epoch': epoch, 'val_loss': (loss_val / len(val_dataset)).item()})



100%|██████████| 8213/8213 [17:42<00:00,  7.73it/s]


epoch : 0 ,loss: 38.1009521484375


100%|██████████| 8213/8213 [20:19<00:00,  6.74it/s]


epoch : 1 ,loss: 26.914196014404297


100%|██████████| 8213/8213 [20:48<00:00,  6.58it/s]


epoch : 2 ,loss: 23.813940048217773


100%|██████████| 8213/8213 [21:22<00:00,  6.40it/s]


epoch : 3 ,loss: 24.08124542236328


100%|██████████| 8213/8213 [21:32<00:00,  6.36it/s]


epoch : 4 ,loss: 21.566354751586914


  0%|          | 39/8213 [00:05<18:31,  7.35it/s]


KeyboardInterrupt: 

In [15]:
torch.save(autoencoder.state_dict(), 'autoencodeur.pth')

In [16]:
model = AutoEncoder()
model.load_state_dict(torch.load('autoencodeur.pth'))


<All keys matched successfully>