In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from Utils.masking import MaskGenerator, visualise_mask
from Utils.dataset import PreloadedDataset
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import math
from tqdm import tqdm
import time

In [20]:
from Utils.masking import MaskGenerator, apply_masks

# Sample one set of context/target masks from a small MaskGenerator running on the CPU.
# NOTE(review): the meaning of input_size / npred / min_keep is defined in Utils.masking,
# which is not visible here — confirm before changing these values.
mask_generator = MaskGenerator(input_size=4, npred=2, device='cpu', min_keep=1)
# sample_masks(1) is unpacked into two values; in the recorded run both were lists of index tensors.
context_masks, target_masks = mask_generator.sample_masks(1)
print(context_masks)
# Bare last expression so the notebook displays the target masks as the cell result.
target_masks

h: 1, w: 2
h: 1, w: 2
h: 3, w: 3
[tensor([[ 5,  9, 10, 11, 13, 14, 15]])]


[tensor([[6, 7]]), tensor([[1, 2]])]

In [23]:
from Utils.nn.resnet_encoder import resnet18
from Utils.nn.conv_mixer import ConvMixer

def _millions_of_trainable_params(net):
    """Return the number of elements in `net`'s trainable parameters, divided by 1e6."""
    trainable_sizes = [weight.numel() for weight in net.parameters() if weight.requires_grad]
    return sum(trainable_sizes) / 1e6


# Build both candidate encoders, then report how many trainable parameters each one has.
resnet = resnet18((1, 224, 224))
convmixer = ConvMixer(dim=512, depth=20)
print(f"Number of parameters in ResNet18: {_millions_of_trainable_params(resnet)}M")
print(f"Number of parameters in ConvMixer: {_millions_of_trainable_params(convmixer)}M")

Number of parameters in ResNet18: 14.650304M
Number of parameters in ConvMixer: 5.570048M


In [24]:
# Shape check: push one random tensor of shape (1, 1, 128, 128) through `convmixer`
# and display the shape of what comes back (torch.Size([1, 512]) in the recorded run).
x = torch.randn(1, 1, 128, 128)
convmixer(x).shape


torch.Size([1, 512])

[tensor([[3, 7]]), tensor([[2, 6]])]

In [2]:
from Models import BYOL, iJEPA
from Utils.utils import get_optimiser
from Utils.cfg import mnist_cfg
from Utils.functional import cosine_schedule

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# MNIST training split converted to tensors; downloaded into ../Datasets/ when missing.
dataset = datasets.MNIST(root='../Datasets/', train=True, download=True, transform=transforms.ToTensor())
# Seed the split so every kernel restart produces the same 48k/12k train/validation
# partition; without a generator the partition differs from run to run.
SPLIT_SEED = 0
train_set, val_set = torch.utils.data.random_split(
    dataset,
    [48000, 12000],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# NOTE(review): the second argument (None) is presumably the transform and the third the
# device the data is preloaded onto — confirm against Utils.dataset.PreloadedDataset.
train_set = PreloadedDataset.from_dataset(train_set, None, device)
# train_set.transform = transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10)

                                                        

In [3]:
# Smoke test: run a single iJEPA training step on one image.
model = iJEPA(1).to(device)
# The teacher handed to train_step is obtained by calling copy() on the model.
teacher = model.copy()
# First element of the first train_set item, with a leading dimension of size 1 added.
img1 = train_set[0][0].unsqueeze(0)
# NOTE(review): in the recorded run this call kept printing
# "Valid mask not found, decreasing acceptable-regions [...]" until it was interrupted —
# presumably the mask generator cannot find a valid mask for this input; confirm.
model.train_step(img1, teacher)

Mask generator says: "Valid mask not found, decreasing acceptable-regions [1]"
Mask generator says: "Valid mask not found, decreasing acceptable-regions [2]"
Mask generator says: "Valid mask not found, decreasing acceptable-regions [3]"
Mask generator says: "Valid mask not found, decreasing acceptable-regions [4]"
Mask generator says: "Valid mask not found, decreasing acceptable-regions [5]"
Mask generator says: "Valid mask not found, decreasing acceptable-regions [6]"
Mask generator says: "Valid mask not found, decreasing acceptable-regions [7]"
Mask generator says: "Valid mask not found, decreasing acceptable-regions [8]"
Mask generator says: "Valid mask not found, decreasing acceptable-regions [9]"
Mask generator says: "Valid mask not found, decreasing acceptable-regions [10]"
Mask generator says: "Valid mask not found, decreasing acceptable-regions [11]"
Mask generator says: "Valid mask not found, decreasing acceptable-regions [12]"
Mask generator says: "Valid mask not found, decre

KeyboardInterrupt: 

In [29]:
import torch
import torch.nn as nn
from Utils.nn.nets import Encoder28

# class BYOL(nn.Module):
#     def __init__(self, in_features):
#         super().__init__()
#         self.in_features = in_features

#         self.num_features = 256
#         self.encoder = Encoder28(256)

#         self.project = nn.Sequential(
#             nn.Linear(self.num_features, 1024, bias=False),
#             nn.BatchNorm1d(1024),
#             nn.ReLU(),
#             nn.Linear(1024, 256, bias=False),
#         )

#         self.predict = nn.Sequential(
#             nn.Linear(256, 512, bias=False),
#             nn.BatchNorm1d(512),
#             nn.ReLU(),
#             nn.Linear(512, 256, bias=False),
#         )

#     def forward(self, x):
#         return self.encoder(x)
    
#     def copy(self):
#         model = BYOL(self.in_features).to(next(self.parameters()).device)
#         model.load_state_dict(self.state_dict())
#         return model

def train(
        online_model,
        optimiser,
        train_dataset,
        num_epochs,
        batch_size,
        augmentation,
        max_epochs=50,
):
    """Train `online_model` against an EMA copy of itself; return per-epoch mean losses.

    NOTE(review): a second `def train(...)` further down in this same cell rebinds the
    name `train`, so once the whole cell has run this version is no longer reachable
    under that name.

    Args:
        online_model: Model being optimised. Must provide `copy()` (used to create the
            target model) and `loss(images, target_model)` returning a tensor that
            supports `backward()`.
        optimiser: Optimiser over `online_model`'s parameters. Each epoch the `lr` of
            every param group is overwritten, and `weight_decay` is overwritten for
            every group whose current weight decay is non-zero.
        train_dataset: Dataset yielding `(images, _)` pairs. Must provide
            `apply_transform(batch_size=...)`, which is called once per epoch.
        num_epochs: Length of the lr / weight-decay / EMA schedules.
        batch_size: Batch size; also scales the base lr (3e-5 * batch_size / 256).
        augmentation: Not used by this function; kept so existing callers keep working.
        max_epochs: Stop after this many epochs even if `num_epochs` is larger
            (default 50, the previously hard-coded cut-off). Pass None to run all
            `num_epochs` epochs.

    Returns:
        list[float]: Mean training loss of every epoch that was run.
    """
    device = next(online_model.parameters()).device

#============================== Online Model Learning Parameters ==============================
    # LR schedule: linear warm-up (at most 10 epochs) followed by cosine decay.
    base_lr = 3e-5 * batch_size / 256
    end_lr = 1e-6
    warmup_epochs = min(10, num_epochs)
    # Drop the leading 0 so the first epoch does not run with a learning rate of exactly zero.
    warm_up_lrs = torch.linspace(0, base_lr, warmup_epochs + 1)[1:]
    if num_epochs > warmup_epochs:
        cosine_lrs = cosine_schedule(base_lr, end_lr, num_epochs - warmup_epochs)
        lrs = torch.cat([warm_up_lrs, cosine_lrs])
    else:
        lrs = warm_up_lrs
    assert len(lrs) == num_epochs

    # WD schedule, cosine
    start_wd = 0.04
    end_wd = 0.4
    wds = cosine_schedule(start_wd, end_wd, num_epochs)

#============================== Target Model Learning Parameters ==============================
    # The target starts as a copy of the online model and then tracks it through an EMA.
    target_model = online_model.copy()
    # EMA schedule, cosine
    start_tau = 0.996
    end_tau = 0.999
    taus = cosine_schedule(start_tau, end_tau, num_epochs)

# ============================== Data Handling ==============================
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# ============================== Training Loop ==============================
    postfix = {}
    train_losses = []
    epochs_to_run = num_epochs if max_epochs is None else min(num_epochs, max_epochs)
    for epoch in range(epochs_to_run):
        online_model.train()
        target_model.train()
        train_dataset.apply_transform(batch_size=batch_size)
        loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
        loop.set_description(f'Epoch [{epoch}/{num_epochs}]')
        if epoch > 0:
            # Show the previous epoch's mean loss while the current epoch is running.
            loop.set_postfix(postfix)

        # Apply this epoch's lr and weight decay. Groups created with weight_decay == 0
        # (e.g. biases / norm layers) are deliberately left undecayed.
        for param_group in optimiser.param_groups:
            param_group['lr'] = lrs[epoch].item()
            if param_group['weight_decay'] != 0:
                param_group['weight_decay'] = wds[epoch].item()

        # Training pass
        tau = taus[epoch]
        epoch_train_losses = torch.zeros(len(train_loader), device=device)
        for i, (images, _) in loop:
            loss = online_model.loss(images, target_model)

            loss.backward()
            optimiser.step()
            optimiser.zero_grad(set_to_none=True)

            # EMA update of the target model after every optimiser step.
            with torch.no_grad():
                for o_param, t_param in zip(online_model.parameters(), target_model.parameters()):
                    t_param.data = tau * t_param.data + (1 - tau) * o_param.data

            epoch_train_losses[i] = loss.detach()

        last_train_loss = epoch_train_losses.mean().item()
        postfix = {'train_loss': last_train_loss}
        train_losses.append(last_train_loss)
        if (epoch+1) % 10 == 0:
            print(f"Epoch {epoch+1} train loss: {last_train_loss}")

    # Always hand the history back; previously None was returned unless epoch 50 was reached.
    return train_losses

def get_optimiser_old(model, optimiser, lr, wd, exclude_bias=True, exclude_bn=True, momentum=0.9, betas=(0.9, 0.999)):
    """Build an AdamW or SGD optimiser with weight decay switched off for selected parameters.

    The returned optimiser has two param groups, in this order:
      0. parameters that receive weight decay `wd`;
      1. parameters with `weight_decay` fixed to 0.0.

    Args:
        model: `torch.nn.Module` whose parameters are optimised.
        optimiser: Either 'AdamW' or 'SGD'.
        lr: Learning rate.
        wd: Weight decay applied to group 0.
        exclude_bias: When True, parameters whose name contains 'bias' are not decayed.
        exclude_bn: When True, batch-norm parameters are not decayed. A parameter counts
            as batch-norm if its name contains 'bn' or if it belongs directly to a
            BatchNorm module.
        momentum: Momentum for SGD (ignored, with a warning, for AdamW).
        betas: Betas for AdamW (ignored, with a warning, for SGD).

    Returns:
        torch.optim.AdamW or torch.optim.SGD.

    Raises:
        AssertionError: If `optimiser` is not 'AdamW' or 'SGD'.
    """
    assert optimiser in ['AdamW', 'SGD'], 'optimiser must be one of ["AdamW", "SGD"]'

    # Name matching alone misses BatchNorm layers placed inside nn.Sequential (their
    # parameters are called e.g. 'project.1.weight'), so also identify them by module type.
    bn_param_ids = set()
    if exclude_bn:
        for module in model.modules():
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                bn_param_ids.update(id(p) for p in module.parameters(recurse=False))

    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        is_bias = exclude_bias and 'bias' in name
        is_bn = exclude_bn and ('bn' in name or id(param) in bn_param_ids)
        if is_bias or is_bn:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    # Group 0 inherits the optimiser-level weight decay; group 1 pins it to zero.
    param_groups = [
        {'params': decay_params},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

    if optimiser == 'AdamW':
        if momentum != 0.9:
            print('Warning: AdamW does not accept momentum parameter. Ignoring it. Please specify betas instead.')
        return torch.optim.AdamW(param_groups, lr=lr, weight_decay=wd, betas=betas)

    if betas != (0.9, 0.999):
        print('Warning: SGD does not accept betas parameter. Ignoring it. Please specify momentum instead.')
    return torch.optim.SGD(param_groups, lr=lr, weight_decay=wd, momentum=momentum)


def train(
        model,
        optimiser,
        train_dataset,
        val_dataset,
        writer,
        cfg:dict,
):
    """Train `model` (optionally against an EMA teacher); return per-epoch mean losses.

    Args:
        model: Model being optimised. Must provide
            `loss(img1=, img2=, actions=, teacher=, epoch=)` returning a tensor that
            supports `backward()`, and `copy()` when `cfg['has_teacher']` is true.
        optimiser: Optimiser over `model`'s parameters. Each epoch the `lr` of every
            param group is overwritten, and `weight_decay` is overwritten for every
            group whose current weight decay is non-zero.
        train_dataset: Dataset yielding `(images, _)` pairs. Must provide
            `apply_transform(batch_size=...)`, which is called once per epoch.
        val_dataset: Not used by this function; kept for interface compatibility.
        writer: Not used by this function; kept for interface compatibility.
        cfg: Configuration dict. Keys read here: 'compute_device', 'ddp_rank',
            'batch_size', 'num_epochs', 'has_teacher', 'start_tau', 'end_tau'
            (the two tau keys only when 'has_teacher' is true), 'master_process',
            'local' and 'dataset'.

    Returns:
        list[float]: Mean training loss of every epoch.

    Raises:
        NotImplementedError: If `cfg['dataset']` is anything other than 'mnist'.
    """
    # Build a real torch.device: Tensor.device never compares equal to a plain string.
    device = torch.device(cfg['compute_device'] + ':' + str(cfg['ddp_rank']))
    num_epochs = cfg['num_epochs']
    batch_size = cfg['batch_size']

#============================== Online Model Learning Parameters ==============================
    # TODO: these schedules are hard-coded. An earlier cfg-driven version read
    # cfg['warmup'], cfg['flat'], cfg['start_lr'], cfg['end_lr'], cfg['decay_lr'],
    # cfg['start_wd'] and cfg['end_wd'] instead.
    # LR schedule: linear warm-up (at most 10 epochs) followed by cosine decay.
    base_lr = 3e-5 * batch_size / 256
    end_lr = 1e-6
    warmup_epochs = min(10, num_epochs)
    # Drop the leading 0 so the first epoch does not run with a learning rate of exactly zero.
    warm_up_lrs = torch.linspace(0, base_lr, warmup_epochs + 1)[1:]
    if num_epochs > warmup_epochs:
        cosine_lrs = cosine_schedule(base_lr, end_lr, num_epochs - warmup_epochs)
        lrs = torch.cat([warm_up_lrs, cosine_lrs])
    else:
        lrs = warm_up_lrs
    assert len(lrs) == num_epochs

    # WD schedule, cosine
    start_wd = 0.04
    end_wd = 0.4
    wds = cosine_schedule(start_wd, end_wd, num_epochs)

#============================== Target Model Learning Parameters ==============================
    if cfg['has_teacher']:
        print('Has teacher')
        # The teacher starts as a copy of the model and then tracks it through an EMA.
        teacher = model.copy()
        # EMA schedule, cosine
        taus = cosine_schedule(cfg['start_tau'], cfg['end_tau'], num_epochs)
    else:
        print('No teacher')
        teacher = None

# ============================== Data Handling ==============================
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# ============================== Training Loop ==============================
    postfix = {}
    train_losses = []
    for epoch in range(num_epochs):
        model.train()

        train_dataset.apply_transform(batch_size=batch_size)

        # Apply this epoch's lr and weight decay. Groups created with weight_decay == 0
        # are deliberately left undecayed.
        for param_group in optimiser.param_groups:
            param_group['lr'] = lrs[epoch].item()
            if param_group['weight_decay'] != 0:
                param_group['weight_decay'] = wds[epoch].item()

        # Training pass
        epoch_train_losses = torch.zeros(len(train_loader), device=device)
        # NOTE(review): gradient norms are collected but not consumed anywhere in this
        # function — presumably intended for `writer`; confirm or remove.
        epoch_train_norms = torch.zeros(len(train_loader), device=device)
        if cfg['master_process'] and cfg['local']:
            # Progress bar only when both the 'master_process' and 'local' flags are set.
            loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
            loop.set_description(f'Epoch [{epoch}/{cfg["num_epochs"]}]')
            if epoch > 0:
                loop.set_postfix(postfix)
        else:
            loop = enumerate(train_loader)

        for i, data in loop:
            # Only MNIST is wired up; it provides a single image batch and no actions.
            images2 = None
            actions = None
            if cfg['dataset'] == 'mnist':
                images1, _ = data
            else:
                raise NotImplementedError(f'Dataset {cfg["dataset"]} not implemented')

            # .to() returns the tensor unchanged when it already lives on `device`.
            images1 = images1.to(device)
            if images2 is not None:
                images2 = images2.to(device)
            if actions is not None:
                actions = actions.to(device)

            loss = model.loss(
                img1=images1,
                img2=images2,
                actions=actions,
                teacher=teacher,
                epoch=epoch
            )

            loss.backward()

            # Global L2 norm over all parameter gradients (taken before the optimiser step).
            epoch_train_norms[i] = torch.norm(torch.stack([torch.norm(p.grad) for p in model.parameters() if p.grad is not None]), 2).detach()

            optimiser.step()

            optimiser.zero_grad(set_to_none=True)

            if cfg['has_teacher']:
                # EMA update of the teacher after every optimiser step.
                with torch.no_grad():
                    for o_param, t_param in zip(model.parameters(), teacher.parameters()):
                        t_param.data = taus[epoch] * t_param.data + (1 - taus[epoch]) * o_param.data

            epoch_train_losses[i] = loss.detach()

        epoch_loss = epoch_train_losses.mean().item()
        postfix = {'train_loss': epoch_loss}
        train_losses.append(epoch_loss)

    # Callers plot the result, so hand the loss history back instead of returning None.
    return train_losses

In [30]:

# Hyper-parameters and objects for the BYOL run on the preloaded MNIST training set.
epochs = 250
batch_size = 256
dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

model = BYOL(1).to(device)
# The result of mnist_cfg(...) is indexed with [0].
# NOTE(review): presumably it returns a sequence of configs — confirm in Utils.cfg.
cfg = mnist_cfg('byol_test', 'byol', 'BYOL')[0]
# The teacher is obtained by calling copy() on the freshly built model.
teacher = model.copy()
optimiser = get_optimiser(model, cfg)


In [31]:
# Run training without a writer and keep whatever `train` hands back.
train_losses = train(
    model=model,
    optimiser=optimiser,
    train_dataset=train_set,
    val_dataset=val_set,
    writer=None,
    cfg=cfg,
)
# Draw the returned values on the current axes, then render the figure.
plt.gca().plot(train_losses)
plt.show()

Has teacher


Epoch [0/250]:   0%|          | 0/188 [00:00<?, ?it/s]

                                                                                    

KeyboardInterrupt: 

In [7]:
# Hand-written BYOL training loop using the notebook-level `model`, `teacher`,
# `optimiser`, `dataloader`, `epochs` and `batch_size`.
# NOTE(review): `augmentation` is called below but is not defined anywhere in the
# visible notebook — it must already exist in the kernel for this cell to run.

# LR schedule: 10 warm-up epochs (leading 0 dropped) followed by cosine decay.
start_lr = 3e-5
end_lr = 1e-6
warmup_lrs = torch.linspace(0, start_lr, 11)[1:]
lrs = cosine_schedule(start_lr, end_lr, epochs-10)
lrs = torch.cat([warmup_lrs, lrs])

wds = cosine_schedule(0.04, 0.4, epochs)
# Size the EMA schedule from `epochs` (it was hard-coded to 250) so that it stays
# aligned with the lr / wd schedules whenever `epochs` is changed.
taus = cosine_schedule(0.996, 1.0, epochs)
losses = []

start_time = time.time()
model.train()
teacher.eval()
for e in range(epochs):

    # train_set.apply_transform(batch_size=batch_size)
    # Apply this epoch's lr and weight decay. Groups created with weight_decay == 0
    # are deliberately left undecayed.
    for param_group in optimiser.param_groups:
        param_group['lr'] = lrs[e].item()
        if param_group['weight_decay'] != 0:
            param_group['weight_decay'] = wds[e].item()

    loop = tqdm(dataloader, total=len(dataloader), leave=False)
    loop.set_description(f'Epoch {e+1}/{epochs}')
    if e > 0:
        # Show the previous epoch's mean loss and the wall-clock time elapsed so far.
        loop.set_postfix(loss=losses[-1], time_elapsed=time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time)))

    epoch_losses = []
    for i, (images, _) in enumerate(loop):
        # Two independently augmented views of the same batch.
        img1, img2 = augmentation(images), augmentation(images)

        # Teacher targets: encode -> project -> L2-normalise, with no gradient tracking.
        with torch.no_grad():
            with torch.autocast(device_type=img1.device.type, dtype=torch.bfloat16):
                y1_t, y2_t = teacher(img1), teacher(img2)
                z1_t, z2_t = teacher.project(y1_t), teacher.project(y2_t)
                z1_t, z2_t = F.normalize(z1_t, dim=-1), F.normalize(z2_t, dim=-1)

        # Online branch: encode -> project -> predict -> L2-normalise.
        with torch.autocast(device_type=img1.device.type, dtype=torch.bfloat16):
            y1_o, y2_o = model(img1), model(img2)
            z1_o, z2_o = model.project(y1_o), model.project(y2_o)
            p1_o, p2_o = model.predict(z1_o), model.predict(z2_o)
            p1_o, p2_o = F.normalize(p1_o, dim=-1), F.normalize(p2_o, dim=-1)

            # Symmetrised loss: each view's prediction is matched to the other view's teacher target.
            loss = 0.5 * (F.mse_loss(p1_o, z2_t) + F.mse_loss(p2_o, z1_t))

        loss.backward()
        optimiser.step()
        optimiser.zero_grad(set_to_none=True)
        epoch_losses.append(loss.item())
    loop.close()
    losses.append(sum(epoch_losses) / len(epoch_losses))

    # NOTE(review): the teacher EMA is applied once per epoch here, whereas the `train`
    # functions in this notebook apply it after every optimiser step — confirm that
    # the per-epoch update is intended.
    with torch.no_grad():
        for s_param, t_param in zip(model.parameters(), teacher.parameters()):
            t_param.data = t_param.data * taus[e] + s_param.data * (1 - taus[e])

    if (e+1) % 10 == 0:
        print(f'Epoch {e+1}/{epochs} - Loss: {losses[-1]}')

    # Stop early after 50 epochs.
    if (e+1) == 50:
        break

plt.plot(losses)
plt.show()

                                                                                                     

KeyboardInterrupt: 