In [1]:
# Notebook parameters. Presumably overridable from outside (e.g. papermill),
# which would explain the explicit casts in the next cell — TODO confirm.

# --- Run / hardware ---------------------------------------------------------
DEVICE: str = "cuda:1"
REPETITIONS: int = 5  # number of independent training runs averaged at the end
TAG: str = "default"  # free-form run label (not referenced in the cells shown here)

# --- Data & augmentation ----------------------------------------------------
BATCH_SIZE: int = 128
MASKING_RATIO: float = 0.2  # forwarded to generate_noisy_xbar for each view
NOISE: str = "mixed"  # "swap_noise", "gaussian", "mixed", "zero"

# --- Model & optimisation ---------------------------------------------------
PROJECTION_SIZE: int = 256
EPOCHS: int = 100

# --- Loss selection ---------------------------------------------------------
INSTANCE_LOSS: str = "simclr"  # "barlow_twins" or "simclr"
CLUSTER_LOSS: str = "simclr"  # "barlow_twins" or "simclr"
BT_LAMBDA: float = 5e-3  # weight of the Barlow Twins off-diagonal (redundancy) term

In [2]:
# Normalise every numeric parameter to its declared type. Values supplied from
# outside the notebook may arrive as strings (presumably why these casts exist);
# BT_LAMBDA and REPETITIONS need the same treatment because they feed arithmetic
# (loss weighting) and range() further down.
BATCH_SIZE = int(BATCH_SIZE)
PROJECTION_SIZE = int(PROJECTION_SIZE)
MASKING_RATIO = float(MASKING_RATIO)
EPOCHS = int(EPOCHS)
BT_LAMBDA = float(BT_LAMBDA)
REPETITIONS = int(REPETITIONS)

In [3]:
# Hyper-parameters handed to modules.network.Network and the AdamW optimiser.
# Layer widths are stored under "<index>_layer_size" keys.
# NOTE(review): four widths are listed while 'n_layers' is 3 — confirm which
# keys modules.network actually reads.
params = {
    'learning_rate': 1e-3,
    'eps': 1e-7,
    'projection_size': PROJECTION_SIZE,
    'n_layers': 3,
    **{f'{index}_layer_size': width for index, width in enumerate((512, 256, 128, 128))},
    'masking_ratio': MASKING_RATIO,
    'noise': NOISE,
}

In [4]:
import sys
sys.path.append("../")

In [5]:
import comet_ml

In [6]:
import numpy as np
import torch
from torch import nn
import tqdm
import torch.optim
from modules import network, contrastive_loss
from utils import yaml_config_hook, save_model
from torch.utils import data
from utils.load_dataset import load_dataset
from evaluation import evaluation
from utils.generate_noise import generate_noisy_xbar

In [7]:
train_dataset, test_dataset = load_dataset("MNIST")

# Labels are never used for training, so both splits are pooled into one set.
class_num = 10  # MNIST has ten digit classes
dataset = data.ConcatDataset([train_dataset, test_dataset])
data_loader = data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # Every batch is exactly BATCH_SIZE long; the SimCLR InstanceLoss is
    # constructed with BATCH_SIZE and presumably relies on this.
    drop_last=True,
    num_workers=1,
)

## Initialize network

The network is instantiated anew for every repetition inside the training loop further down.

## Implement Barlow Twins loss

In [8]:
class BarlowTwinsLoss(nn.Module):
    """Barlow Twins-style redundancy-reduction loss between two embeddings.

    Each input is L2-normalised column-wise (``dim=0``), so entry ``(i, j)`` of
    the cross matrix is the (uncentred) cosine similarity between feature ``i``
    of ``z_a`` and feature ``j`` of ``z_b`` over the batch.

    The loss is ``sum_i (C_ii - 1)^2 + lbd * sum_{i != j} C_ij^2``.

    Args:
        lbd: weight of the off-diagonal (redundancy) term.
    """

    def __init__(self, lbd) -> None:
        super().__init__()
        self.lbd = lbd

    def forward(self, z_a, z_b) -> torch.Tensor:
        """Return the scalar loss for two batches with the same feature width."""
        cross = nn.functional.normalize(z_a, dim=0).T @ nn.functional.normalize(z_b, dim=0)
        size = cross.shape[0]

        # Invariance term: drive every diagonal entry towards 1.
        invariance = ((cross.diagonal() - 1) ** 2).sum()

        # Redundancy term: squared off-diagonal entries (diagonal masked to 0).
        off_diagonal = 1 - torch.eye(size, dtype=cross.dtype, device=cross.device)
        redundancy = (cross.pow(2) * off_diagonal).sum()

        return invariance + self.lbd * redundancy

In [9]:
# Standalone instance used only by the sanity check below; the training cell
# builds its own criteria.
bt_loss = BarlowTwinsLoss(lbd=BT_LAMBDA)

In [10]:
# Sanity check, seeded so the displayed value is reproducible. Identical views
# give a unit diagonal (invariance term ~0), so they must score far below two
# unrelated random views, whose diagonal is close to 0.
generator = torch.Generator().manual_seed(0)
z_a = torch.randn(256, 128, generator=generator)
z_unrelated = torch.randn(256, 128, generator=generator)
assert bt_loss(z_a, z_a) < bt_loss(z_a, z_unrelated)
bt_loss(z_a, z_a)

tensor(0.3115)

## Prepare clustering evaluation

In [11]:
def cluster(model, data_loader, class_num=10):
    """Score the clustering produced by ``model`` on every sample in ``data_loader``.

    Cluster assignments from ``model.forward_cluster`` are gathered for the
    whole loader and scored once with ``evaluation.evaluate``. Clustering
    accuracy is permutation-invariant: clusters are matched to labels,
    presumably inside ``evaluation.evaluate`` (verify). Scoring each batch
    separately would let every batch pick its own matching and overstate
    accuracy.

    The model's train/eval mode is restored on exit, so this is safe to call
    between training epochs.

    Args:
        model: network exposing ``forward_cluster(x)``. Assumed to return one
            cluster index per sample — TODO confirm against modules.network.
        data_loader: iterable of ``(x, y)`` batches; ``x`` is moved to ``DEVICE``.
        class_num: number of clusters, forwarded to ``evaluation.evaluate``.

    Returns:
        The accuracy reported by ``evaluation.evaluate`` as a float, or NaN
        when the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    y_true, y_pred = [], []
    try:
        with torch.no_grad():
            for x, y in data_loader:
                y_pred.extend(model.forward_cluster(x.to(DEVICE)).cpu().tolist())
                y_true.extend(y.tolist())
    finally:
        # Without this, a caller that never re-enables training mode would keep
        # optimising the model in eval mode after the first evaluation.
        model.train(was_training)

    if not y_true:
        return float("nan")  # nothing to score

    _nmi, _ari, _f, acc = evaluation.evaluate(y_true, y_pred, class_num)
    return float(acc)

## Train the model

In [12]:
loss_device = torch.device(DEVICE)

# Reject unknown loss names up front. Otherwise anything other than
# "barlow_twins" (e.g. a typo) silently selects the SimCLR-style loss while the
# run is still logged under the mistyped name.
_SUPPORTED_LOSSES = ("barlow_twins", "simclr")
for _setting, _choice in (("INSTANCE_LOSS", INSTANCE_LOSS), ("CLUSTER_LOSS", CLUSTER_LOSS)):
    if _choice not in _SUPPORTED_LOSSES:
        raise ValueError(f"{_setting} must be one of {_SUPPORTED_LOSSES}, got {_choice!r}")

if INSTANCE_LOSS == "barlow_twins":
    criterion_instance = BarlowTwinsLoss(BT_LAMBDA)
else:
    # Arguments: batch size, 0.5 (presumably the temperature — verify in
    # modules.contrastive_loss), device.
    criterion_instance = contrastive_loss.InstanceLoss(BATCH_SIZE, 0.5, loss_device).to(loss_device)

if CLUSTER_LOSS == "barlow_twins":
    criterion_cluster = BarlowTwinsLoss(BT_LAMBDA)
else:
    # Arguments: number of clusters, 1.0 (presumably the temperature — verify),
    # device.
    criterion_cluster = contrastive_loss.ClusterLoss(class_num, 1.0, loss_device).to(loss_device)

In [13]:
final_accs = []
logged_params = {
    'batch_size': BATCH_SIZE,
    'masking_ratio': MASKING_RATIO,
    'noise': NOISE,
    'bt_lambda': BT_LAMBDA,
    'projection_size': PROJECTION_SIZE,
    'epochs': EPOCHS,
    'instance_loss': INSTANCE_LOSS,
    'cluster_loss': CLUSTER_LOSS,
}
print("Start training on device: {}".format(DEVICE))
print(logged_params)


def _train_epoch(model, optimizer):
    """Run one optimisation pass over ``data_loader``.

    Returns:
        Per-batch means of (total loss, instance loss, cluster loss).
    """
    # cluster() calls model.eval(), so training mode is re-enabled every epoch.
    model.train()
    total_sum = instance_sum = cluster_sum = 0.0
    for x, _ in data_loader:
        optimizer.zero_grad()
        # Two independently corrupted views of the same batch.
        x_i = generate_noisy_xbar(x.clone(), params['noise'], params['masking_ratio']).to(DEVICE)
        x_j = generate_noisy_xbar(x.clone(), params['noise'], params['masking_ratio']).to(DEVICE)
        z_i, z_j, c_i, c_j = model(x_i, x_j)

        loss_instance = criterion_instance(z_i, z_j)
        loss_cluster = criterion_cluster(c_i, c_j)
        loss = loss_instance + loss_cluster
        loss.backward()
        optimizer.step()

        instance_sum += loss_instance.item()
        cluster_sum += loss_cluster.item()
        total_sum += loss.item()
    n_batches = max(len(data_loader), 1)
    return total_sum / n_batches, instance_sum / n_batches, cluster_sum / n_batches


for repetition in range(REPETITIONS):
    # One seed per repetition: runs are reproducible yet differ from each other.
    # numpy is seeded too in case generate_noisy_xbar draws from it (unverified).
    torch.manual_seed(repetition)
    np.random.seed(repetition)

    # No credentials in the notebook: comet_ml reads the API key from the
    # COMET_API_KEY environment variable or its config file.
    experiment = comet_ml.Experiment(
        project_name="subtab_cluster",
        workspace="wwydmanski",
    )
    try:
        experiment.log_parameters(params)
        experiment.log_parameters(logged_params)
        experiment.log_parameter("seed", repetition)
        experiment.log_code()

        model = network.Network(784, params, class_num).to(DEVICE)
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=params['learning_rate'],
            weight_decay=1e-3,
            betas=(0.9, 0.999),
            eps=params['eps'],
        )

        acc = float("nan")  # defined even when EPOCHS == 0
        for epoch in tqdm.trange(EPOCHS):
            loss_mean, loss_instance_mean, loss_cluster_mean = _train_epoch(model, optimizer)
            acc = cluster(model, data_loader)
            experiment.log_metrics(
                {
                    # Every loss is a per-batch mean, so the curves share a scale.
                    "loss": loss_mean,
                    "acc": acc,
                    # Key kept for dashboard continuity; holds the instance loss
                    # whichever criterion INSTANCE_LOSS selects.
                    "loss_bt": loss_instance_mean,
                    "loss_cluster": loss_cluster_mean,
                },
                step=epoch,
                epoch=epoch,
            )
        final_accs.append(acc)
    finally:
        # Close each run explicitly; otherwise it can stay open in the kernel
        # while the next repetition's experiment starts.
        experiment.end()



Start training on device: cuda:1
{'batch_size': 128, 'masking_ratio': 0.2, 'noise': 'mixed', 'bt_lambda': 0.005, 'projection_size': 256, 'epochs': 100, 'instance_loss': 'simclr', 'cluster_loss': 'simclr'}


COMET INFO: Experiment is live on comet.com https://www.comet.com/wwydmanski/subtab-cluster/6248cf41c82e43d4b074bcf1c5fabffc

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Mean ~ standard deviation of the final-epoch clustering accuracy across repetitions.
acc_mean = np.mean(final_accs)
acc_std = np.std(final_accs)
print(round(acc_mean, 3), "~", round(acc_std, 3))