In [20]:
import torch
import torch.nn.functional as F
import torchvision
import numpy as np
import torch
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from torch import nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import tqdm
from torchvision.transforms import v2

In [21]:
torch.manual_seed(0)

<torch._C.Generator at 0x272feecfc50>

In [22]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using device:', device)

Using device: cuda


### define models and functions

In [23]:
def EMA(new, alpha=0.99, old=None):
    if old is None:
        return new
    else:
        return old * alpha + (1 - alpha) * new

In [24]:
def loss_fn(x, y):
   # L2 normalization
   x = F.normalize(x, dim=-1, p=2)
   y = F.normalize(y, dim=-1, p=2)
   return 2 - 2 * (x * y).sum(dim=-1)

In [25]:
class MLP(nn.Module):
    def __init__(self, input_dim=2048) -> None:
        super(MLP, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Linear(4096, 256)
        )
    
    def forward(self, x):
        return self.net(x)

In [26]:
class TargetModel(nn.Module):
    def __init__(self) -> None:
        super(TargetModel, self).__init__()
        self.encoder = torchvision.models.resnet50()
        self.encoder.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.encoder.fc = nn.Identity()

        self.represent = MLP()
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.represent(x)
        return x

In [27]:
class OnlineModel(nn.Module):
    def __init__(self) -> None:
        super(OnlineModel, self).__init__()
        self.encoder = torchvision.models.resnet50()
        self.encoder.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.encoder.fc = nn.Identity()

        self.represent = MLP()

        self.predictor = MLP(input_dim=256)
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.represent(x)
        x = self.predictor(x)
        return x


### Load data + augmentations

In [28]:
transform = transforms.RandomApply([
    transforms.RandomHorizontalFlip(p=0.6),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.3, 0.1)], p=0.6),
    transforms.RandomGrayscale(p=0.6),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.6), 
    ]
)

In [29]:
dataset_train = datasets.CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(),
                                  transform])
)

Files already downloaded and verified


In [30]:
type(dataset_train[0][0])

torch.Tensor

In [31]:
batch_size = 128

In [32]:
train_loader = DataLoader(
    dataset_train, 
    batch_size=batch_size,
    shuffle=True
)

### Init

In [33]:
target_model = TargetModel().to(device)
online_model = OnlineModel().to(device)

In [34]:
online_model.encoder

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [35]:
for p in target_model.parameters():
    p.require_grads = False

In [36]:
opt = torch.optim.Adam(online_model.parameters())

In [37]:
online_model.train()
target_model.train()

for epoch in range(20):
    print('Epoch: %d' % (epoch + 1))
    count = 0
    train_loss = 0
    for x1, y in tqdm.tqdm(train_loader):
        count += 1
        opt.zero_grad()

        x1 = x1.to(device).float()
        x2 = x1.clone().to(device)

        x1 = transform(x1)
        x2 = transform(x2)
        with torch.no_grad():
            x1_res = target_model(x1)
        x2_res = online_model(x2)

        loss = loss_fn(x1_res, x2_res).mean()
        # print(f"batch {count}: loss {loss.item()}")

        train_loss += loss.item()
        # update target model params
        with torch.no_grad():
            for online_params, target_params in zip(online_model.parameters(), target_model.parameters()):
                old_weight, up_weight = target_params.data, online_params.data
                target_params.data = EMA(new=up_weight, old=old_weight)
        loss.backward()
        opt.step()
        
    print(f"train_loss: {train_loss / count}")

Epoch: 1


100%|██████████| 391/391 [00:40<00:00,  9.57it/s]


train_loss: 0.18844862027889323
Epoch: 2


100%|██████████| 391/391 [00:42<00:00,  9.23it/s]


train_loss: 0.1476792116242144
Epoch: 3


100%|██████████| 391/391 [00:44<00:00,  8.76it/s]


train_loss: 0.1265060906834386
Epoch: 4


100%|██████████| 391/391 [00:43<00:00,  8.93it/s]


train_loss: 0.07757409511234069
Epoch: 5


100%|██████████| 391/391 [00:44<00:00,  8.75it/s]


train_loss: 0.05891699741458725
Epoch: 6


100%|██████████| 391/391 [00:40<00:00,  9.64it/s]


train_loss: 0.04504007154651691
Epoch: 7


100%|██████████| 391/391 [00:40<00:00,  9.70it/s]


train_loss: 0.03871504927073103
Epoch: 8


100%|██████████| 391/391 [00:40<00:00,  9.77it/s]


train_loss: 0.032526867140246475
Epoch: 9


100%|██████████| 391/391 [00:39<00:00,  9.78it/s]


train_loss: 0.02966230197707215
Epoch: 10


100%|██████████| 391/391 [00:40<00:00,  9.75it/s]


train_loss: 0.03129426334196192
Epoch: 11


100%|██████████| 391/391 [00:39<00:00,  9.88it/s]


train_loss: 0.02963981412999008
Epoch: 12


100%|██████████| 391/391 [00:40<00:00,  9.77it/s]


train_loss: 0.026079685543961537
Epoch: 13


100%|██████████| 391/391 [00:39<00:00,  9.78it/s]


train_loss: 0.023145706230855507
Epoch: 14


100%|██████████| 391/391 [00:40<00:00,  9.73it/s]


train_loss: 0.0236405571565375
Epoch: 15


100%|██████████| 391/391 [00:40<00:00,  9.77it/s]


train_loss: 0.02177644046285497
Epoch: 16


100%|██████████| 391/391 [00:40<00:00,  9.72it/s]


train_loss: 0.02070858079316023
Epoch: 17


100%|██████████| 391/391 [00:39<00:00,  9.79it/s]


train_loss: 0.019414314671474343
Epoch: 18


100%|██████████| 391/391 [00:39<00:00,  9.78it/s]


train_loss: 0.020144676125091514
Epoch: 19


100%|██████████| 391/391 [00:40<00:00,  9.77it/s]


train_loss: 0.016173908307724406
Epoch: 20


100%|██████████| 391/391 [00:39<00:00,  9.82it/s]

train_loss: 0.020199842570001816





In [38]:
torch.save(online_model.state_dict(), "./results/SSL/pretrained/online")