In [2]:
import os
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from lightly.loss import NTXentLoss
from lightly.models.modules import (
    NNCLRPredictionHead,
    NNCLRProjectionHead,
    NNMemoryBankModule,
)
import wandb

In [7]:

# Feature tensors are expected in a "features_mae_large" folder under the
# current working directory.
data_path = os.path.join(os.getcwd(), "features_mae_large")

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

### Understand Data

In [3]:
# Inspect the shape of the center-crop test features, then free the memory.
# os.path.join is required: data_path has no trailing separator, so plain
# string concatenation would glue the file name onto the directory name.
test_center_crop = torch.load(os.path.join(data_path, "mae_l23_cls_test_centercrop.th"))
print(test_center_crop.shape)
del test_center_crop

torch.Size([50000, 1024])


In [4]:
# Inspect the shape of the center-crop train features, then free the memory.
# os.path.join is required: data_path has no trailing separator, so plain
# string concatenation would glue the file name onto the directory name.
train_center_crop = torch.load(os.path.join(data_path, "mae_l23_cls_train_centercrop.th"))
print(train_center_crop.shape)
del train_center_crop

torch.Size([1281167, 1024])


In [5]:
# Inspect the shape of the view-1 / seed-0 tensor file, then free the memory.
# os.path.join is required: data_path has no trailing separator.
v1_seed0 = torch.load(os.path.join(data_path, "tensors_v1_seed_0.th"))
print(v1_seed0.shape)
del v1_seed0

torch.Size([1281167, 1024])


In [6]:
# Inspect the shape of the view-1 / seed-2 tensor file, then free the memory.
# os.path.join is required: data_path has no trailing separator.
v1_seed2 = torch.load(os.path.join(data_path, "tensors_v1_seed_2.th"))
print(v1_seed2.shape)
del v1_seed2

torch.Size([1281167, 1024])


In [7]:
# Inspect the shape of the view-1 / seed-3 tensor file, then free the memory.
# os.path.join is required: data_path has no trailing separator.
v1_seed3 = torch.load(os.path.join(data_path, "tensors_v1_seed_3.th"))
print(v1_seed3.shape)
del v1_seed3

torch.Size([1281167, 1024])


In [8]:
# Inspect the shape of the view-1 / seed-4 tensor file, then free the memory.
# os.path.join is required: data_path has no trailing separator.
v1_seed4 = torch.load(os.path.join(data_path, "tensors_v1_seed_4.th"))
print(v1_seed4.shape)
del v1_seed4

torch.Size([1281167, 1024])


In [9]:
# Inspect the shape of the view-1 / seed-5 tensor file, then free the memory.
# os.path.join is required: data_path has no trailing separator.
v1_seed5 = torch.load(os.path.join(data_path, "tensors_v1_seed_5.th"))
print(v1_seed5.shape)
del v1_seed5

torch.Size([1281167, 1024])


In [10]:
# Inspect the shape of the view-1 / seed-6 tensor file, then free the memory.
# os.path.join is required: data_path has no trailing separator.
v1_seed6 = torch.load(os.path.join(data_path, "tensors_v1_seed_6.th"))
print(v1_seed6.shape)
del v1_seed6

torch.Size([1281167, 1024])


In [11]:
# Inspect the shape of the view-2 / seed-0 tensor file, then free the memory.
# os.path.join is required: data_path has no trailing separator.
v2_seed0 = torch.load(os.path.join(data_path, "tensors_v2_seed_0.th"))
print(v2_seed0.shape)
del v2_seed0

torch.Size([1281167, 1024])


In [12]:
# Inspect the shape of the view-2 / seed-1 tensor file, then free the memory.
# os.path.join is required: data_path has no trailing separator.
v2_seed1 = torch.load(os.path.join(data_path, "tensors_v2_seed_1.th"))
print(v2_seed1.shape)
del v2_seed1

torch.Size([1281167, 1024])


In [13]:
# Inspect the shape of the view-2 / seed-2 tensor file, then free the memory.
# os.path.join is required: data_path has no trailing separator.
v2_seed2 = torch.load(os.path.join(data_path, "tensors_v2_seed_2.th"))
print(v2_seed2.shape)
del v2_seed2

torch.Size([1281167, 1024])


In [14]:
# Inspect the shape of the view-2 / seed-3 tensor file, then free the memory.
# os.path.join is required: data_path has no trailing separator.
v2_seed3 = torch.load(os.path.join(data_path, "tensors_v2_seed_3.th"))
print(v2_seed3.shape)
del v2_seed3

torch.Size([1281167, 1024])


In [15]:
# Inspect the shape of the view-2 / seed-4 tensor file, then free the memory.
# os.path.join is required: data_path has no trailing separator.
v2_seed4 = torch.load(os.path.join(data_path, "tensors_v2_seed_4.th"))
print(v2_seed4.shape)
del v2_seed4

torch.Size([1281167, 1024])


In [16]:
# Inspect the shape of the view-2 / seed-5 tensor file, then free the memory.
# os.path.join is required: data_path has no trailing separator.
v2_seed5 = torch.load(os.path.join(data_path, "tensors_v2_seed_5.th"))
print(v2_seed5.shape)
del v2_seed5

torch.Size([1281167, 1024])


In [17]:
# Inspect the shape of the view-2 / seed-6 tensor file, then free the memory.
# os.path.join is required: data_path has no trailing separator.
v2_seed6 = torch.load(os.path.join(data_path, "tensors_v2_seed_6.th"))
print(v2_seed6.shape)
del v2_seed6

torch.Size([1281167, 1024])


In [18]:
# Inspect the test labels: total count and number of distinct label values.
# os.path.join is required: data_path has no trailing separator.
test_labels = torch.load(os.path.join(data_path, "test-labels.th"))
print(test_labels.shape)
print(test_labels.unique().shape)
del test_labels

torch.Size([50000])
torch.Size([1000])


In [19]:
# Inspect the train labels: total count and number of distinct label values.
# os.path.join is required: data_path has no trailing separator.
train_labels = torch.load(os.path.join(data_path, "train-labels.th"))
print(train_labels.shape)
print(train_labels.unique().shape)
del train_labels

torch.Size([1281167])
torch.Size([1000])


### Dataset, Model, and Hyperparameters

In [3]:
class MAEDataset(torch.utils.data.Dataset):
    """Pairs up two objects loaded from disk with ``torch.load``, one per view.

    Item ``i`` is the tuple ``(view1[i], view2[i])``. Both files must
    deserialize to objects with identical ``shape``.
    """

    def __init__(self, view1_path, view2_path):
        # Load view 1 first, then view 2; both are kept fully in memory.
        self.data1, self.data2 = (
            torch.load(path) for path in (view1_path, view2_path)
        )
        assert self.data2.shape == self.data1.shape, "view1 and view2 must have the same shape"
        # Sample count is the leading dimension shared by both views.
        self.length = self.data1.shape[0]

    def __getitem__(self, idx):
        first = self.data1[idx]
        second = self.data2[idx]
        return first, second

    def __len__(self):
        return self.length

In [4]:
class NNCLRHead(nn.Module):
    """Projection head followed by a prediction head, applied to input features.

    Args:
        project_hidden_dim: hidden width passed to ``NNCLRProjectionHead``.
        project_output_dim: output width of the projection head; also used
            as the input width of the prediction head.
        predict_hidden_dim: hidden width passed to ``NNCLRPredictionHead``.
        predict_output_dim: output width of the prediction head.
        input_dim: width of the incoming feature vectors. Defaults to 1024,
            the value that used to be hard-coded, so existing callers are
            unaffected.
    """

    def __init__(self, project_hidden_dim, project_output_dim, 
                 predict_hidden_dim, predict_output_dim, input_dim=1024):
        super().__init__()

        self.projection_head = NNCLRProjectionHead(input_dim, # input_dim 
                                                   project_hidden_dim, # hidden_dim 
                                                   project_output_dim) # output_dim
        self.prediction_head = NNCLRPredictionHead(project_output_dim, # input_dim 
                                                   predict_hidden_dim, # hidden_dim
                                                   predict_output_dim) # output_dim

    def forward(self, x):
        """Return ``(z, p)``: the detached projection of ``x`` and its prediction."""
        z = self.projection_head(x)
        p = self.prediction_head(z)
        # Detach only after p has been computed: gradients still reach the
        # projection head through p, while the returned z carries none.
        z = z.detach()
        return z, p

In [5]:
# Hyperparameters

# --- optimisation ---
EPOCHS = 10             # passes over the dataset
BATCH_SIZE = 1024       # samples per DataLoader batch
LR = 0.0001             # optimizer learning rate
WEIGHT_DECAY = 0.00001  # optimizer weight decay

# --- contrastive objective ---
TEMPERATURE = 0.15      # temperature given to the NT-Xent loss
QUEUE_SIZE = 2 ** 16    # 65536 entries in each memory bank

# --- head widths ---
PROJECT_HIDDEN_DIM = 2048
PROJECT_OUTPUT_DIM = 256
PREDICTION_HIDDEN_DIM = 4096
PREDICTION_OUTPUT_DIM = 256

# Training

In [9]:
# Paths of the two seed-0 view files handed to MAEDataset; backslashes are
# normalised to forward slashes.
view1_path = os.path.join(data_path, "tensors_v1_seed_0.th").replace("\\", "/")
view2_path = os.path.join(data_path, "tensors_v2_seed_0.th").replace("\\", "/")
# Show the resolved location so the recorded cell output is reproducible.
print(view1_path)

c:/Users/lopez/Projects/jku-pr/features_mae_large/tensors_v1_seed_0.th


In [6]:
# Nearest neighbours of the projections (z) are contrasted against the
# predictions (p), so both heads must emit vectors of the same width.
assert PROJECT_OUTPUT_DIM == PREDICTION_OUTPUT_DIM, (
    "PROJECT_OUTPUT_DIM and PREDICTION_OUTPUT_DIM must match for the NNCLR loss"
)

model = NNCLRHead(PROJECT_HIDDEN_DIM, 
                  PROJECT_OUTPUT_DIM, 
                  PREDICTION_HIDDEN_DIM, 
                  PREDICTION_OUTPUT_DIM)
model.to(device)
# The bank is queried and updated with projection outputs (z0 / z1 below),
# so it is sized by the projection width rather than the prediction width.
memory_bank = NNMemoryBankModule(size=(QUEUE_SIZE, PROJECT_OUTPUT_DIM))
memory_bank.to(device)
# NOTE(review): the loss is given its own memory bank in addition to the
# nearest-neighbour bank above; the reference NNCLR setup in lightly appears
# to use a plain NTXentLoss without one -- confirm this is intentional.
criterion = NTXentLoss(temperature=TEMPERATURE, memory_bank_size=(QUEUE_SIZE, PREDICTION_OUTPUT_DIM))
optimizer = torch.optim.SGD(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

dataset = MAEDataset(view1_path, view2_path)
dataloader = torch.utils.data.DataLoader(
    dataset, 
    batch_size=BATCH_SIZE,
    shuffle=True,
)


print("Starting Training")
for epoch in range(EPOCHS):
    # Accumulated as a detached tensor, so no per-step .item() call is needed.
    total_loss = 0
    for x0, x1 in dataloader:
        x0, x1 = x0.to(device), x1.to(device)
        # z*: detached projections, p*: predictions that carry gradients.
        z0, p0 = model(x0)
        z1, p1 = model(x1)
        # Pass each projection through the memory bank; only the second call
        # updates the bank with the current batch.
        z0 = memory_bank(z0, update=False) # update can be True for z0 xor z1
        z1 = memory_bank(z1, update=True)
        # Symmetrised loss: each view's bank output against the other view's prediction.
        loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

Starting Training
epoch: 00, loss: 10.61001
epoch: 01, loss: 8.99533
epoch: 02, loss: 7.80246
epoch: 03, loss: 7.22053
epoch: 04, loss: 6.89556
epoch: 05, loss: 6.68435
epoch: 06, loss: 6.53453
epoch: 07, loss: 6.42043
epoch: 08, loss: 6.32916
epoch: 09, loss: 6.25517


In [7]:
# Ask PyTorch's CUDA caching allocator to release memory it holds but is not
# currently using. Tensors that are still referenced (e.g. the model) are
# presumably unaffected -- verify if memory pressure persists.
torch.cuda.empty_cache()

### Quantitative metrics