In [1]:
import torch
import torch.nn.functional as F
import torchvision
import torch
from torch import nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import tqdm
import matplotlib.pyplot as plt

In [2]:
plt.style.use("ggplot")

In [3]:
torch.manual_seed(0)

<torch._C.Generator at 0x18a2adcfc50>

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using device:', device)

Using device: cuda


### define models and functions

In [5]:
def EMA(new, alpha=0.99, old=None):
    if old is None:
        return new
    else:
        return old * alpha + (1 - alpha) * new

In [6]:
def loss_fn(x, y):
   # L2 normalization
   x = F.normalize(x, dim=-1, p=2)
   y = F.normalize(y, dim=-1, p=2)
   return 2 - 2 * (x * y).sum(dim=-1)

In [7]:
class MLP(nn.Module):
    def __init__(self, input_dim=2048) -> None:
        super(MLP, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Linear(4096, 256)
        )
    
    def forward(self, x):
        return self.net(x)

In [8]:
class TargetModel(nn.Module):
    def __init__(self) -> None:
        super(TargetModel, self).__init__()
        self.encoder = torchvision.models.resnet50()
        self.encoder.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.encoder.fc = nn.Identity()
        self.encoder.maxpool = torch.nn.Identity()

        self.represent = MLP()
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.represent(x)
        return x

In [9]:
class OnlineModel(nn.Module):
    def __init__(self) -> None:
        super(OnlineModel, self).__init__()
        self.encoder = torchvision.models.resnet50()
        self.encoder.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.encoder.fc = nn.Identity()
        self.encoder.maxpool = torch.nn.Identity()

        self.represent = MLP()

        # self.predictor = MLP(input_dim=256)
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.represent(x)
        # x = self.predictor(x)
        return x


In [None]:
class BYOL(nn.Module):
    def __init__(self,
                 moving_average_decay=0.99) -> None:
        super(BYOL, self).__init__()

        self.student_model = OnlineModel()
        self.teacher_model = TargetModel()
        self.moving_average_decay = moving_average_decay
        self.student_predictor = MLP(input_dim=256)

    @torch.no_grad()
    def update_moving_average(self):
       assert self.teacher_model is not None, 'target encoder has not been created yet'
       for student_params, teacher_params in zip(self.student_model.parameters(), self.teacher_model.parameters()):
         old_weight, up_weight = teacher_params.data, student_params.data
         teacher_params.data = EMA(old=old_weight, new=up_weight, alpha=self.moving_average_decay)

    def forward(self,
                image1,
                image2):
       # student projections: backbone + MLP projection
       student_proj_one = self.student_model(image1)
       student_proj_two = self.student_model(image2)

       # additional student's MLP head called predictor
       student_pred_one = self.student_predictor(student_proj_one)
       student_pred_two = self.student_predictor(student_proj_two)

       with torch.no_grad():
           # teacher processes the images and makes projections: backbone + MLP
           teacher_proj_one = self.teacher_model(image1).detach_()
           teacher_proj_two = self.teacher_model(image2).detach_()


       loss_one = loss_fn(student_pred_one, teacher_proj_one)
       loss_two = loss_fn(student_pred_two, teacher_proj_two)       

       return (loss_one + loss_two).mean()

### Load data + augmentations

In [10]:
transform = transforms.Compose([
                                transforms.RandomHorizontalFlip(p=0.5),
                                transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.3, 0.1)], p=0.5),
                                transforms.RandomGrayscale(p=0.5),
                                transforms.RandomApply([transforms.GaussianBlur(3)], p=0.5), 
                                ])

In [11]:
dataset_train = datasets.CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()])
)

Files already downloaded and verified


In [12]:
batch_size = 128

In [13]:
train_loader = DataLoader(
    dataset_train, 
    batch_size=batch_size,
    shuffle=True
)

### Train

In [None]:
torch.cuda.empty_cache()

In [14]:
byol = BYOL()
opt = torch.optim.Adam(byol.parameters(), lr=0.003)
epochs = 20
byol.train()

In [15]:
# for p in target_model.parameters():
#     p.require_grads = False

In [18]:
losses = []
for epoch in range(epochs):
    print('Epoch: %d' % (epoch + 1))
    count = 0
    train_loss = 0
    for x, y in tqdm.tqdm(train_loader):
        count += 1
        opt.zero_grad()

        x1 = transform(x).to(device).float()
        x2 = transform(x).to(device).float()

        loss = byol(x1, x2)
        loss.backward()
        opt.step()
        byol.update_moving_average()

        train_loss += loss.item()
        
    losses.append(train_loss / count)
    print(f"train_loss: {train_loss / count}")
    torch.save(byol.student_model.encoder.state_dict(), f"./pretrained_feature_extractors/feature_extractor_{epoch+1}")


Epoch: 1


100%|██████████| 391/391 [01:36<00:00,  4.07it/s]


train_loss: 0.03737498081677482
Epoch: 2


100%|██████████| 391/391 [01:36<00:00,  4.04it/s]


train_loss: 0.0019414913358257325
Epoch: 3


100%|██████████| 391/391 [01:35<00:00,  4.10it/s]


train_loss: 0.000578703972852796
Epoch: 4


100%|██████████| 391/391 [01:33<00:00,  4.16it/s]


train_loss: 0.0001380612016255048
Epoch: 5


100%|██████████| 391/391 [01:34<00:00,  4.13it/s]


train_loss: 3.2501262820684865e-05
Epoch: 6


100%|██████████| 391/391 [01:34<00:00,  4.16it/s]


train_loss: 9.236966386134506e-06
Epoch: 7


100%|██████████| 391/391 [01:35<00:00,  4.09it/s]


train_loss: 1.8549330362798627e-06
Epoch: 8


100%|██████████| 391/391 [01:35<00:00,  4.08it/s]


train_loss: 8.169756940113358e-07
Epoch: 9


100%|██████████| 391/391 [01:36<00:00,  4.05it/s]


train_loss: 5.285505710355461e-07
Epoch: 10


100%|██████████| 391/391 [01:36<00:00,  4.04it/s]


train_loss: 4.894898065825557e-07


In [26]:
plt.plot(losses, label='train_loss')
plt.legend()
plt.savefig(f"./pretrained_feature_extractors/train_loss", bbox_inches='tight')
plt.clf()

<Figure size 640x480 with 0 Axes>

In [27]:
torch.cuda.empty_cache()