# Training with Contrastive Learning

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
#from torchvision import datasets, models, transforms
from torchvision.models import resnet18,ResNet18_Weights
from torchvision.transforms import v2
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import tqdm
from glob import glob
import random
from torchvision.io import read_image, ImageReadMode

Load an ImageNet-pretrained ResNet-18 and replace its final fully connected layer (`fc`) with a linear layer that outputs a 256-dimensional embedding vector.

In [2]:
import ssl

# SECURITY: TLS certificate verification is disabled so torchvision can download
# the pretrained weights (presumably this host's CA bundle cannot verify the
# download server -- TODO: fix the CA bundle instead). The previous factory is
# restored afterwards, so verification is not left off for every HTTPS request
# made later in this kernel session.
_saved_https_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
try:
    model = resnet18(weights=ResNet18_Weights.DEFAULT)
finally:
    ssl._create_default_https_context = _saved_https_context

# Swap the pretrained classification head for a 256-d embedding projection.
num_feats = model.fc.in_features
model.fc = nn.Linear(num_feats, 256)

In [3]:
# Training-time augmentation pipeline. The dataset below applies it separately to
# the anchor, positive and negative images, so two calls on the same picture give
# two different random views of it.
transform = v2.Compose(
    [
        v2.ToImage(),
        v2.RandomResizedCrop(size=(200, 200), antialias=True),
        v2.RandomRotation(degrees=30),                # angle drawn from [-30, 30]
        v2.RandomPerspective(),                       # library-default distortion
        v2.ToDtype(dtype=torch.float32, scale=True),  # integer pixels -> float in [0, 1]
    ]
)

class ContLearnDataset(Dataset):
    def __init__(self, dir, transforms=None):
        self.dir = dir
        self.transforms = transforms
        self.filenames = glob(dir+'/*.jpg')

    def __getitem__(self, idx):
        a = read_image(self.filenames[idx], mode=ImageReadMode.RGB)
        p = a.detach().clone()
        i = torch.randint(len(self.filenames),(1,))[0]
        while i==idx:
            i = torch.randint(len(self.filenames),(1,))[0]
        n = read_image(self.filenames[i], mode=ImageReadMode.RGB)
        return self.transforms(a), self.transforms(p), self.transforms(n)

    def __len__(self):
        return len(self.filenames)

In [4]:
# Training images: a flat directory of JPEGs.
# NOTE(review): absolute, machine-specific path -- move to a config cell / env var.
TRAIN_DIR = '/home/scs/taylor/faculty/midPics'

cld = ContLearnDataset(TRAIN_DIR, transforms=transform)

# shuffle=True so anchors arrive in a new order each epoch and batch composition
# varies; DataLoader's default (shuffle=False) replays the same batches every epoch.
dl = DataLoader(cld, batch_size=64, shuffle=True, num_workers=10)

In [5]:
model = model.to('cuda')
EPOCHS = 501      # 501 so the last epoch (500) is a multiple of LOG_EVERY and is logged
LOG_EVERY = 10

# Library defaults: margin=1.0, Euclidean (p=2) distance between embeddings.
criterion = nn.TripletMarginLoss()
optimizer = optim.Adam(model.parameters(), lr=.001)

# Set training mode explicitly: if an inference cell was run first (out-of-order
# execution) the model could still be in eval mode, and BatchNorm would not train.
model.train()
for epoch in tqdm.tqdm(range(EPOCHS)):
    totalloss = 0  # sum of per-batch mean triplet losses over this epoch
    for a, p, n in dl:
        a, p, n = a.to('cuda'), p.to('cuda'), n.to('cuda')
        # Separate forward passes, so each role gets its own BatchNorm batch stats.
        aem = model(a)
        pem = model(p)
        nem = model(n)
        loss = criterion(aem, pem, nem)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        totalloss += loss.item()
    if epoch % LOG_EVERY == 0:
        # tqdm.write keeps the progress bar intact instead of interleaving prints.
        tqdm.tqdm.write(str(totalloss))
# Pickles the whole nn.Module (not just a state_dict); the load cell relies on this.
torch.save(model, 'fr.pt')

  0%|                                         | 1/501 [00:14<1:57:14, 14.07s/it]

27.044744189828634


  2%|▉                                       | 11/501 [02:32<1:53:17, 13.87s/it]

5.7630290780216455


  4%|█▋                                      | 21/501 [04:51<1:51:05, 13.89s/it]

5.274194166064262


  6%|██▍                                     | 31/501 [07:10<1:48:49, 13.89s/it]

5.330866588279605


  8%|███▎                                    | 41/501 [09:28<1:46:23, 13.88s/it]

2.3687154948711395


 10%|████                                    | 51/501 [11:49<1:45:09, 14.02s/it]

2.302560755982995


 12%|████▊                                   | 61/501 [14:08<1:42:17, 13.95s/it]

1.5873109828680754


 14%|█████▋                                  | 71/501 [16:28<1:39:48, 13.93s/it]

1.8012434244155884


 16%|██████▍                                 | 81/501 [18:50<1:39:42, 14.25s/it]

1.8421833366155624


 18%|███████▎                                | 91/501 [21:11<1:36:47, 14.16s/it]

1.2791936565190554


 20%|███████▊                               | 101/501 [23:33<1:33:37, 14.04s/it]

1.3381973393261433


 22%|████████▋                              | 111/501 [25:54<1:31:22, 14.06s/it]

1.7387591637670994


 24%|█████████▍                             | 121/501 [28:14<1:28:57, 14.05s/it]

0.8022745810449123


 26%|██████████▏                            | 131/501 [30:35<1:27:00, 14.11s/it]

0.9764685370028019


 28%|██████████▉                            | 141/501 [32:56<1:24:23, 14.06s/it]

0.6003412045538425


 30%|███████████▊                           | 151/501 [35:16<1:21:23, 13.95s/it]

0.4993991144001484


 32%|████████████▌                          | 161/501 [37:35<1:19:11, 13.97s/it]

1.1437131464481354


 34%|█████████████▎                         | 171/501 [39:55<1:16:57, 13.99s/it]

0.8815404213964939


 36%|██████████████                         | 181/501 [42:15<1:14:33, 13.98s/it]

0.5106975063681602


 38%|██████████████▊                        | 191/501 [44:35<1:12:32, 14.04s/it]

0.7315564006567001


 40%|███████████████▋                       | 201/501 [46:56<1:10:05, 14.02s/it]

0.619711622595787


 42%|████████████████▍                      | 211/501 [49:15<1:07:24, 13.95s/it]

0.6931751146912575


 44%|█████████████████▏                     | 221/501 [51:35<1:05:00, 13.93s/it]

0.4532166160643101


 46%|█████████████████▉                     | 231/501 [53:54<1:02:33, 13.90s/it]

0.4766850247979164


 48%|██████████████████▊                    | 241/501 [56:13<1:00:14, 13.90s/it]

0.9874532371759415


 50%|████████████████████▌                    | 251/501 [58:32<58:05, 13.94s/it]

0.27798614650964737


 52%|████████████████████▎                  | 261/501 [1:00:51<55:33, 13.89s/it]

0.5396388620138168


 54%|█████████████████████                  | 271/501 [1:03:10<53:15, 13.89s/it]

0.4586564600467682


 56%|█████████████████████▊                 | 281/501 [1:05:29<50:58, 13.90s/it]

0.4623073786497116


 58%|██████████████████████▋                | 291/501 [1:07:48<48:39, 13.90s/it]

0.6000809408724308


 60%|███████████████████████▍               | 301/501 [1:10:08<46:43, 14.02s/it]

0.49066393077373505


 62%|████████████████████████▏              | 311/501 [1:12:28<44:24, 14.02s/it]

0.5984493792057037


 64%|████████████████████████▉              | 321/501 [1:14:48<42:02, 14.01s/it]

0.5013175867497921


 66%|█████████████████████████▊             | 331/501 [1:17:08<39:21, 13.89s/it]

0.5686188861727715


 68%|██████████████████████████▌            | 341/501 [1:19:28<37:17, 13.99s/it]

0.46712958067655563


 70%|███████████████████████████▎           | 351/501 [1:21:47<34:55, 13.97s/it]

0.5566348768770695


 72%|████████████████████████████           | 361/501 [1:24:08<32:49, 14.07s/it]

0.09223124757409096


 74%|████████████████████████████▉          | 371/501 [1:26:26<30:03, 13.87s/it]

0.37909093499183655


 76%|█████████████████████████████▋         | 381/501 [1:28:46<28:00, 14.00s/it]

0.6807845085859299


 78%|██████████████████████████████▍        | 391/501 [1:31:06<25:42, 14.02s/it]

0.25458578392863274


 80%|███████████████████████████████▏       | 401/501 [1:33:26<23:22, 14.03s/it]

0.46401767805218697


 82%|███████████████████████████████▉       | 411/501 [1:35:46<20:56, 13.97s/it]

0.33175375685095787


 84%|████████████████████████████████▊      | 421/501 [1:38:06<18:34, 13.93s/it]

0.4047399237751961


 86%|█████████████████████████████████▌     | 431/501 [1:40:25<16:12, 13.90s/it]

0.6077894405461848


 88%|██████████████████████████████████▎    | 441/501 [1:42:44<13:52, 13.88s/it]

0.32870616763830185


 90%|███████████████████████████████████    | 451/501 [1:45:03<11:36, 13.92s/it]

0.5758156292140484


 92%|███████████████████████████████████▉   | 461/501 [1:47:22<09:16, 13.91s/it]

0.39727046340703964


 94%|████████████████████████████████████▋  | 471/501 [1:49:41<06:57, 13.91s/it]

0.30676137283444405


 96%|█████████████████████████████████████▍ | 481/501 [1:52:00<04:38, 13.90s/it]

0.30618688464164734


 98%|██████████████████████████████████████▏| 491/501 [1:54:19<02:18, 13.89s/it]

0.5570314861834049


100%|███████████████████████████████████████| 501/501 [1:56:38<00:00, 13.97s/it]

0.09460968151688576





In [7]:
# 'fr.pt' holds a whole pickled nn.Module (saved with torch.save(model, ...)), so
# full unpickling is required: weights_only=False must be explicit on torch>=2.6,
# where the default became True. Unpickling can execute arbitrary code -- only
# load checkpoints you produced yourself.
model = torch.load('fr.pt', weights_only=False)
# Inference follows: use BatchNorm running statistics rather than batch stats.
model.eval()
model = model.to('cuda')

In [9]:
# Sanity check: the anchor embedding should sit much closer to its positive
# (another view of the same image) than to the negative.
# NOTE(review): `cld2` is not defined anywhere in this notebook -- presumably a
# held-out ContLearnDataset from a deleted cell. Define it explicitly so the
# notebook survives Restart & Run All.
a, p, n = cld2[20]
# Stack the three CHW images into one batch. This replaces the hard-coded
# reshape to (1, 3, 112, 92) and works for whatever size the dataset produces.
vals = torch.stack((a, p, n), dim=0).to('cuda')
# eval(): use BatchNorm running stats instead of statistics of this 3-image batch
# (train mode would also overwrite the running averages). no_grad: no autograd graph.
model.eval()
with torch.no_grad():
    res = model(vals)
print(res.shape)
# mse_loss = mean squared difference over the 256 embedding dims.
print(f'a/p: {nn.functional.mse_loss(res[0],res[1])}')
print(f'a/n: {nn.functional.mse_loss(res[0],res[2])}')
print(f'p/n: {nn.functional.mse_loss(res[1],res[2])}')


torch.Size([3, 256])
a/p: 0.014126414433121681
a/n: 0.5725901126861572
p/n: 0.5608456134796143


In [29]:
# Deterministic preprocessing for gallery images. The training `transform` applies
# random crops/rotations/perspective, which would give a different embedding each
# time an image is processed. Here we only resize to the training resolution
# (200x200) and convert to float in [0, 1].
eval_transform = v2.Compose([
    v2.ToImage(),
    v2.Resize(size=(200, 200), antialias=True),
    v2.ToDtype(torch.float32, scale=True),
])

# Sorted, and kept alongside `gallery`, so gallery[i] maps to gallery_files[i]
# reproducibly.
# NOTE(review): absolute, machine-specific path -- move to a config cell.
gallery_files = sorted(glob('/home/scs/taylor/faculty/onlySD312/*.jpg'))
gallery = [eval_transform(read_image(fn, mode=ImageReadMode.RGB)) for fn in gallery_files]

In [37]:
# Embed every gallery image, one at a time.
# eval(): in train mode a one-image batch is normalised with its own BatchNorm
# statistics and also overwrites the model's running averages.
# no_grad(): inference only, so no autograd graph is built.
model.eval()
embeds = []
with torch.no_grad():
    for img in gallery:
        single = img.unsqueeze(0).to('cuda')  # add batch dim: CHW -> 1xCHW
        embeds.append(model(single).cpu().numpy())

In [38]:
# Stack the per-image (1, D) embeddings into one (N, D) matrix. np.vstack keeps
# this cell idempotent: if it is re-run, `embeds` is already 2-D and comes back
# unchanged. np.concatenate(embeds, axis=0) would instead flatten it to 1-D.
embeds = np.vstack(embeds)
embeds.shape

(39, 256)