In [1]:
from models.BYOL2_model import BYOL2
from data.custom_transforms import BatchTransform, ListToTensor, PadToSquare, SelectFromTuple
from data.pairs_dataset import PairsDataset, pair_collate_fn
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import sys
import torch
from torch.utils.data import Subset
import torchvision.models as models
import torchvision.transforms as T
import warnings

In [2]:
# Training data: PairsDataset is built from two parallel directories -- the Shoes
# training images and (per the directory name) their PiDiNet edge maps.
# NOTE(review): absolute, user-specific paths; adjust for your machine.
TRAIN_IMAGES_DIR = '/users/jmorales/Shoes/images_train/'
TRAIN_EDGES_DIR = '/users/jmorales/Shoes/images_train_pidinet/'
IMAGE_SIZE = 224
BATCH_SIZE = 50

train_dataset = PairsDataset(TRAIN_IMAGES_DIR, TRAIN_EDGES_DIR)


def make_view_transform(view_index, image_size=IMAGE_SIZE, device='cuda'):
    """Build the batch pipeline that extracts one view from each collated pair.

    Args:
        view_index: which element of each pair tuple to keep
            (presumably 0 = image, 1 = edge map -- confirm against PairsDataset).
        image_size: side length passed to ``T.Resize``.
        device: device argument forwarded to ``ListToTensor``.

    Returns:
        A ``T.Compose`` that selects element ``view_index`` from every item in the
        batch, applies ``PadToSquare(255)`` (255 presumably being the pad fill),
        resizes to ``image_size`` x ``image_size`` and converts the batch with
        ``ListToTensor(device, torch.float)``.
    """
    return T.Compose([
        BatchTransform(SelectFromTuple(view_index)),
        BatchTransform(PadToSquare(255)),
        BatchTransform(T.Resize((image_size, image_size))),
        ListToTensor(device, torch.float),
    ])


# One pipeline per element of the pair; these become the learner's two "views".
transforms_1 = make_view_transform(0)
transforms_2 = make_view_transform(1)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=pair_collate_fn,
    num_workers=4,
)

In [6]:
# ResNet-50 encoder initialised from a BYOL checkpoint (QuickDraw, per the file name).
# NOTE(review): the next cell loads a full learner state dict, which presumably
# overwrites these encoder weights too -- confirm whether this load is still needed.
encoder = models.resnet50(pretrained=False)
# The encoder is a CPU module here, so load the checkpoint tensors onto the CPU
# instead of wherever they were saved.
encoder.load_state_dict(
    torch.load('../checkpoints/resnet50_byol_quickdraw_128_1000_v3.pt', map_location='cpu')
)

# BYOL2's own augmentation is replaced by an empty Compose; the per-view pipelines
# are attached as augment1/augment2 below (presumably consumed by BYOL2.forward --
# confirm in models/BYOL2_model.py).
empty_transform = T.Compose([])
epochs = 5
epoch_size = len(train_loader)  # number of batches per epoch
learner = BYOL2(
    encoder,
    image_size=224,
    hidden_layer='avgpool',
    augment_fn=empty_transform,
    # total optimisation steps; presumably the length of the cosine EMA schedule
    cosine_ema_steps=epochs * epoch_size,
)
learner.augment1 = transforms_1
learner.augment2 = transforms_2

# Move the model to the GPU *before* building the optimizer, as the torch.optim
# docs recommend, so the optimizer is constructed over the parameters it will update.
learner = learner.to('cuda')
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

In [7]:
# Start from the Sketchy-pretrained bimodal weights and copy them to the output
# checkpoint path, which the training cell below loads from and then overwrites.
sketchy_checkpoint = '../checkpoints/bimodal_byol_resnet50_pretrained_sketchy_v5.pt'
shoes_checkpoint = '../checkpoints/self_bimodal_byol_sketchy_then_shoes_{}epochs.pt'.format(epochs)

learner.load_state_dict(torch.load(sketchy_checkpoint))
torch.save(learner.state_dict(), shoes_checkpoint)

In [8]:
# Fine-tune the learner on Shoes. Each step logs the epoch's running mean loss to
# stdout and to a text log, and the weights are checkpointed about twice per epoch.
checkpoint_path = '../checkpoints/self_bimodal_byol_sketchy_then_shoes_{}epochs.pt'.format(epochs)
log_path = '../checkpoints/training_bimodal_byol.txt'

# NOTE(review): this loads whatever is currently at checkpoint_path. Straight after the
# seeding cell that is the Sketchy-pretrained weights; if this cell is re-run on its
# own it resumes from the last mid-training save instead, and `opt` keeps its Adam
# state from the previous run.
learner.load_state_dict(torch.load(checkpoint_path))
learner = learner.to('cuda')
learner.train()

# Integer save interval: a float divisor such as epoch_size / 2 never divides the
# batch counter evenly when epoch_size is odd, which would skip the mid-epoch save.
save_every = max(1, epoch_size // 2)

# `with` guarantees the log file is closed even if a training step raises.
with open(log_path, 'w') as filehandler, warnings.catch_warnings():
    # NOTE(review): this silences *all* warnings for the whole run.
    warnings.filterwarnings('ignore')
    for epoch in range(epochs):
        loss_sum = 0.0  # running total so the per-epoch mean is O(1) per step
        for i, images in enumerate(train_loader, start=1):
            loss = learner(images)
            opt.zero_grad()
            loss.backward()
            opt.step()
            learner.update_moving_average()

            loss_sum += loss.item()
            progress = 'Epoch {}, batch {} - loss {:.4f}'.format(epoch + 1, i, loss_sum / i)
            sys.stdout.write('\r' + progress)
            filehandler.write(progress + '\n')
            filehandler.flush()

            # Checkpoint at each interval and always after the epoch's last batch.
            if i % save_every == 0 or i == epoch_size:
                torch.save(learner.state_dict(), checkpoint_path)
        sys.stdout.write('\n')

Epoch 1, batch 36 - loss 0.2459
Epoch 2, batch 36 - loss 0.1377
Epoch 3, batch 36 - loss 0.0938
Epoch 4, batch 36 - loss 0.0688
Epoch 5, batch 36 - loss 0.0543
