In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
import torchvision
import os
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import time
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
#from models import * #model and dataset
from train import *

# Choose the compute device used by every `.to(device)` call below.
# When CUDA is present, release cached allocator memory first and target GPU 0;
# otherwise fall back to the CPU.
if not torch.cuda.is_available():
    device = "cpu"
else:
    torch.cuda.empty_cache()
    device = "cuda:0"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---- Contrastive pre-training hyperparameters ----
dim = 128        # number of rows of the key queue (queue is built as dim x K)
K = 2 ** 16      # number of columns of the key queue, i.e. stored keys (65536)
m = 0.999        # momentum coefficient for the key-encoder weight update
T = 0.07         # temperature the logits are divided by before the loss

# ---- Run configuration ----
epoch = 100      # number of passes over the dataset
batch = 8        # DataLoader batch size
imgpath = 'Dataset/train_images'  # directory handed to MocoSet

# ---- Initialization: encoders, data pipeline, optimizer, key queue ----

# Two encoders of the same architecture: `modelq` is optimized by SGD,
# `modelk` is only ever updated as a moving average of `modelq`.
modelq = encoder().to(device)
modelk = encoder().to(device)

dataset = MocoSet(imgpath)
loader = DataLoader(dataset, batch_size=batch, shuffle=True)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(modelq.parameters(), lr=1e-3)

# Start the key encoder from the query encoder's weights and exclude its
# parameters from gradient computation.
with torch.no_grad():
    for query_param, key_param in zip(modelq.parameters(), modelk.parameters()):
        key_param.data.copy_(query_param.data)
        key_param.requires_grad_(False)

# Queue of keys, shape (dim, K): random columns normalized to unit length
# along dim 0 so each column is a unit vector.
kqueue = nn.functional.normalize(torch.randn(dim, K).to(device), dim=0)

# Per-epoch loss history (for the plot) and the best epoch loss seen so far.
lossplt = []
bestloss = 10000

# Contrastive pre-training loop.
#
# Per batch: encode the two inputs with the query / key encoders, score the
# query against its own key (column 0 of the logits) and against every key in
# `kqueue` (remaining columns), and train `modelq` with cross-entropy towards
# class 0. Afterwards `modelk` is moved towards `modelq` by momentum `m`, and
# the fresh keys are pushed onto the front of the queue while the same number
# of columns is dropped from its end.
#
# Per epoch: the sample-weighted mean loss is printed, appended to `lossplt`,
# plotted to 'pretrainloss.png', and `modelq` is saved to 'encoder.pt'
# whenever the epoch loss improves on `bestloss`.
for i in range(epoch):
    modelq.train()
    # Keep the running loss as a plain Python float: accumulating the loss
    # tensor itself would chain autograd history across the whole epoch.
    total_loss = 0.0
    seen = 0  # number of samples actually processed this epoch
    for inputq, inputk in tqdm(loader):
        inputq, inputk = inputq.to(device), inputk.to(device)

        q = nn.functional.normalize(modelq(inputq), dim=1)
        with torch.no_grad():
            # Keys never receive gradients.
            k = nn.functional.normalize(modelk(inputk), dim=1)

        # Positive logits: (N, 1); negative logits against the queue: (N, K).
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # The queue is rebound below (never mutated in place), so a detached
        # view is enough; no per-step copy of the whole queue is needed.
        l_neg = torch.einsum('nc,ck->nk', [q, kqueue.detach()])
        logits = torch.cat([l_pos, l_neg], dim=1) / T

        # The positive is always column 0. Create the labels on the same
        # device as the logits so the loop also runs when device == "cpu".
        batch_size = logits.shape[0]
        labels = torch.zeros(batch_size, dtype=torch.long, device=logits.device)
        loss = loss_function(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `loss` is the mean over this batch; weight it by the batch size so
        # a smaller final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size
        seen += batch_size

        # Momentum update of the key encoder: pk <- m * pk + (1 - m) * pq.
        with torch.no_grad():
            for pq, pk in zip(modelq.parameters(), modelk.parameters()):
                pk.data = pk.data * m + pq.data * (1 - m)

        # Enqueue the new keys at the front and drop the same number of
        # columns from the end, keeping the queue width constant.
        kqueue = torch.cat([k.T, kqueue[:, :kqueue.size(1) - k.size(0)]], dim=1)

    # Mean loss per sample; max() guards against an empty loader.
    total_loss /= max(seen, 1)
    if total_loss < bestloss:
        bestloss = total_loss
        # NOTE: saves the whole module object (not a state_dict); kept as-is
        # so existing code that loads 'encoder.pt' keeps working.
        torch.save(modelq, 'encoder.pt')
    lossplt.append(total_loss)
    print('Loss: ', total_loss)

    # Redraw the loss curve from scratch every epoch and overwrite the image.
    plt.plot(lossplt, label='loss')
    plt.legend()
    plt.savefig('pretrainloss.png')
    plt.close()


  0%|          | 0/113 [00:02<?, ?it/s]

torch.Size([128, 65536])





In [7]:
# Sanity check: report how many samples the dataset object exposes.
dataset_size = len(dataset)
print(dataset_size)


900
