In [None]:
import datasets
from IPython.display import display
from sentence_transformers import SentenceTransformer
import torch
from transformers import AutoTokenizer
import torchvision.models as models
import matplotlib.pyplot as plt
import copy
import numpy as np
import os
import pickle
from torchsummary import summary
from textwrap import wrap
# Run tensors on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Directory prefix for the files this notebook reads and writes (embedding cache, checkpoints).
path_prepend = './drive/MyDrive/project/'
%matplotlib inline

In [None]:
# Load the "train" split of the '2m_random_10k' configuration of DiffusionDB.
dataset = datasets.load_dataset(
  'poloclub/diffusiondb',
  '2m_random_10k',
  split="train",
)

In [None]:
# Preview a 3x3 grid of randomly chosen samples, each titled with its
# line-wrapped prompt.
fig, ax = plt.subplots(3, 3, figsize=(25,25))
ax = ax.ravel()
for panel in ax:
  # Read the row once and reuse it for both the picture and the title.
  sample = dataset[np.random.randint(0,len(dataset))]
  panel.imshow(sample['image'])
  panel.set_title("\n".join(wrap(sample['prompt'])))

fig.tight_layout()
plt.show()

In [None]:
# Compute (or load from the on-disk cache) one sentence embedding per prompt
# and attach the result to the dataset as an "embeddings" column.
embeddings_path = path_prepend+'embeddings.pkl'
# Materialise the prompt column once; it is reused for the cache check, the
# encoder input and the cache payload.
prompts = list(dataset['prompt'])

# Explicit "nothing loaded yet" sentinel. Probing locals() made the outcome
# depend on whatever a previous run of this cell left in the kernel.
prompt_embeddings = None
if os.path.exists(embeddings_path):
  # NOTE(security): pickle.load can execute arbitrary code. Only point this at
  # a cache file that this notebook wrote itself.
  with open(embeddings_path,'rb') as fIn:
    stored_data = pickle.load(fIn)
  # Reuse the cache only if it was built from exactly these prompts, in order.
  if list(stored_data['sentences']) == prompts:
    prompt_embeddings = stored_data['embeddings']

if prompt_embeddings is None:
  prompt_embeddings = SentenceTransformer('all-miniLM-L6-v2', device=device).encode(prompts)
  with open(embeddings_path, "wb") as fOut:
    pickle.dump({'sentences': prompts, 'embeddings': prompt_embeddings}, fOut, protocol=pickle.HIGHEST_PROTOCOL)
print(f'Size of embeddings: {prompt_embeddings.shape}')

# add_column receives one entry per row, so split the 2-D result into a list
# of per-prompt vectors. `embeddings` keeps its original meaning for later cells.
embeddings = [embedding for embedding in prompt_embeddings]
dataset = dataset.add_column("embeddings",embeddings)

In [None]:
def transforms(examples):
    """Add model-ready "x" and "y" entries to a batch of examples.

    Each entry of ``examples["image"]`` is converted to RGB, resized to
    512x512 and turned into an array whose third axis is moved to the front
    (channel-first, assuming the image converts to a height x width x channel
    array). ``examples["y"]`` is set to ``examples["embeddings"]`` unchanged.

    Args:
        examples: Batch mapping with "image" and "embeddings" entries.

    Returns:
        The same ``examples`` object, with "x" and "y" added.
    """
    channel_first = []
    for image in examples["image"]:
        pixels = np.asarray(image.convert("RGB").resize((512,512)))
        channel_first.append(np.transpose(pixels, (2, 0, 1)))
    examples["x"] = channel_first
    examples["y"] = examples["embeddings"]
    return examples

# Apply `transforms` batch by batch, then drop the source columns so only the
# remaining columns plus the new "x"/"y" are kept.
dataset = dataset.map(
    transforms,
    batched=True,
    remove_columns=["image", "embeddings"],
)

In [None]:
# No torch/numpy output format has been set on the dataset at this point, so a
# row's "x" comes back as nested Python lists, which have no `.shape`.
# Wrapping it in an array makes the shape check work on a fresh run.
print(f'Image size: {np.asarray(dataset[0]["x"]).shape}')

In [None]:
# Return only the "x" and "y" columns, as torch tensors, when rows are read.
dataset.set_format(type='torch', columns=['x','y'])
# Hold out 20% of the rows for evaluation; the split itself is not shuffled.
dataset = dataset.train_test_split(test_size = 0.2, shuffle = False)
dataloaders = {
  # Reshuffle the training rows every epoch so SGD does not see the same
  # batch order each time. Evaluation order is left untouched.
  'train': torch.utils.data.DataLoader(dataset['train'], batch_size=16, shuffle=True),
  'test': torch.utils.data.DataLoader(dataset['test'], batch_size=32),
}
# Row counts per split, used to turn summed batch losses into per-sample means.
dataset_sizes = {split: len(dataset[split]) for split in ('train', 'test')}

In [None]:
# Start from a pretrained ResNet-34 and swap its final fully connected layer
# for one whose output size equals the length of a prompt embedding.
model = models.resnet34(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, len(embeddings[0]))
# The training loop moves every batch to `device`; the weights must live there
# too, otherwise the forward pass fails with a device mismatch on CUDA.
# Done before the optimizer is created.
model = model.to(device)
summary(model)

In [None]:
# SGD with momentum over all model parameters, paired with a scheduler that
# applies an exponential decay (gamma=0.9) each time it is stepped.
optimizer = torch.optim.SGD(
  model.parameters(),
  lr=0.01,
  momentum=0.9,
)
scheduler = torch.optim.lr_scheduler.ExponentialLR(
  optimizer,
  gamma=0.9,
)

In [None]:
# Exclusive upper bound of the epoch range iterated by the training loop.
num_epochs = 5

In [None]:
def loss_fn(y, y_pred):
  """Cosine-distance loss between target and predicted embeddings.

  Computes the cosine similarity of each row pair along dim=1, averages it
  over the batch and returns ``1 - mean_similarity``. The result is 0 when
  every prediction points in the same direction as its target and 2 when
  every prediction points the opposite way.

  The previous form, ``1/mean_similarity - 1``, is undefined at 0 and tends
  to -inf as the mean similarity approaches 0 from below, so minimising it
  could push predictions away from the targets.

  Args:
    y: Target embeddings, one per row.
    y_pred: Predicted embeddings with the same shape as ``y``.

  Returns:
    A scalar tensor holding the mean cosine distance of the batch.
  """
  mean_similarity = torch.mean(torch.nn.functional.cosine_similarity(y, y_pred, dim=1))
  return 1 - mean_similarity

In [None]:
def checkpoint(model, path, optim=None):
  """Save the model and optimizer state dicts to ``path``.

  Args:
    model: Module whose ``state_dict()`` is stored under the 'model' key.
    path: Destination passed to ``torch.save``.
    optim: Optimizer whose ``state_dict()`` is stored under the 'optimizer'
      key. Defaults to the notebook-level ``optimizer``, which keeps existing
      two-argument calls working.
  """
  # Only fall back to the notebook global when no optimizer is passed in.
  active_optimizer = optimizer if optim is None else optim
  torch.save({
      'optimizer': active_optimizer.state_dict(),
      'model':model.state_dict(),
      }, path)

def resume(model, path, optim=None):
  """Restore model and optimizer state from a file holding 'model' and 'optimizer' state dicts.

  Args:
    model: Module that receives the saved 'model' state dict.
    path: Checkpoint file to read with ``torch.load``.
    optim: Optimizer that receives the saved 'optimizer' state dict.
      Defaults to the notebook-level ``optimizer``, which keeps existing
      two-argument calls working.
  """
  # Only fall back to the notebook global when no optimizer is passed in.
  active_optimizer = optimizer if optim is None else optim
  # Load onto the CPU so a checkpoint written on a GPU machine can still be
  # read without CUDA.
  state = torch.load(path, map_location='cpu')
  model.load_state_dict(state['model'])
  # The optimizer must be restored from its own entry; feeding it the model
  # weights (as before) makes load_state_dict fail.
  active_optimizer.load_state_dict(state['optimizer'])

In [None]:
# Train for `num_epochs` epochs. Each epoch runs a 'train' phase (weights are
# updated) followed by a 'test' phase (evaluation only). A checkpoint is
# written after every epoch, and the weights with the lowest test loss are
# kept and loaded back into the model at the end.
best_loss = torch.inf
start_epoch = 0  # set > 0 to continue from the checkpoint written after epoch start_epoch - 1
if start_epoch > 0:
  resume_epoch = start_epoch - 1
  resume(model, path_prepend+f'epoch-{resume_epoch}.pth')

# Seed the "best" weights with the current ones so the final load_state_dict
# cannot hit an undefined name if no epoch ever improves on best_loss
# (for example when the test loss is NaN).
best_model_wts = copy.deepcopy(model.state_dict())

for epoch in range(start_epoch, num_epochs):
  print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  print('-' * 10)
  for phase in ['train', 'test']:
    is_training = phase == 'train'
    # train(True) enables training mode; train(False) is the same as eval().
    model.train(is_training)
    running_loss = 0.0
    for data in dataloaders[phase]:
      # NOTE(review): pixel values are only cast to float32 here; no scaling or
      # normalisation is applied before they reach the pretrained network.
      # Confirm this is intended.
      x = data['x'].to(device).type(torch.float32)
      y = data['y'].to(device)
      optimizer.zero_grad()
      # Only track gradients while training.
      with torch.set_grad_enabled(is_training):
        outputs = model(x)
        loss = loss_fn(y, outputs)

        if is_training:
          loss.backward()
          optimizer.step()
      # Weight the batch-mean loss by the batch size so that dividing by the
      # split size below yields a per-sample mean. `.item()` gives a plain
      # float with no reference to the graph.
      running_loss += loss.item()*x.size(0)

    if is_training:
      # Decay the learning rate once per epoch, after the optimizer has
      # stepped. Stepping the scheduler first would skip the initial rate.
      scheduler.step()

    epoch_loss = running_loss/dataset_sizes[phase]
    print('{} Loss:{:.4f}'.format(phase,epoch_loss))
    if phase == 'test' and epoch_loss < best_loss:
      best_loss = epoch_loss
      best_model_wts = copy.deepcopy(model.state_dict())
      checkpoint(model, path_prepend+"best-model.pth")
  print(f'End of epoch {epoch}')
  checkpoint(model, path_prepend+f'epoch-{epoch}.pth')

print('Best Loss: {:.4f}'.format(best_loss))
model.load_state_dict(best_model_wts)