In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from Utils.masking import MaskGenerator, visualise_mask
from Utils.dataset import PreloadedDataset
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import math
from tqdm import tqdm
import time

In [2]:
from Models import AE, iGPA

SEED = 0  # seeds the train/val split and all later torch RNG (model init, shuffling, augmentation)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = datasets.MNIST(root='../Datasets/', train=True, download=True, transform=transforms.ToTensor())
# 80/20 train/val split (48000/12000 on the 60000-image MNIST training set). The split uses its
# own seeded generator so val_set is identical across runs.
n_val = len(dataset) // 5
train_set, val_set = torch.utils.data.random_split(
    dataset, [len(dataset) - n_val, n_val], generator=torch.Generator().manual_seed(SEED)
)
train_set = PreloadedDataset.from_dataset(train_set, None, device)
# train_set.transform = transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10)

                                                        

In [3]:
from Utils.nn.nets import Encoder28, Decoder28

class EncBlock(nn.Module):
    """Conv2d -> optional 2x2 max-pool -> optional BatchNorm2d -> ReLU.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        bn: if False, the BatchNorm step is skipped in ``forward``.
        pool: if True, a 2x2 / stride-2 max-pool follows the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bn=True, pool=False):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        if pool:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        else:
            self.pool = nn.Identity()
        # Registered regardless of `bn`; it is simply bypassed in forward when bn=False.
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.should_bn = bn

    def forward(self, x):
        features = self.pool(self.conv(x))
        if not self.should_bn:
            return self.relu(features)
        return self.relu(self.bn(features))
    
class DecBlock(nn.Module):
    """ConvTranspose2d (channel-preserving) -> optional 2x upsample -> 3x3 Conv2d to out_channels.

    No nonlinearity is applied inside the block.

    Args:
        in_channels: channels of the input and of the transposed-conv output.
        out_channels: channels produced by the final 3x3 convolution.
        kernel_size, stride, padding: forwarded to ``nn.ConvTranspose2d``.
        upsample: if True, an ``nn.Upsample(scale_factor=2)`` follows the transposed conv.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, upsample=False):
        super().__init__()
        self.convt = nn.ConvTranspose2d(in_channels, in_channels, kernel_size, stride, padding)
        if upsample:
            self.upsample = nn.Upsample(scale_factor=2)
        else:
            self.upsample = nn.Identity()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)

    def forward(self, x):
        return self.conv(self.upsample(self.convt(x)))

class mnist_cnn_encoder(nn.Module):
    """Five-stage EncBlock encoder returning a flat (B, num_features) code.

    For 28x28 inputs the spatial size goes 28 -> 14 -> 7 (padded convs + pooling)
    -> 5 -> 3 -> 1 (unpadded 3x3 convs), so the flattened output has num_features entries.
    """

    def __init__(self, num_features):
        super().__init__()
        # (in_ch, out_ch, padding, extra EncBlock kwargs); every conv is 3x3 with stride 1.
        layout = [
            (1, 32, 1, {'pool': True}),
            (32, 64, 1, {'pool': True}),
            (64, 128, 0, {}),
            (128, 256, 0, {}),
            (256, num_features, 0, {'bn': False}),  # last stage skips BatchNorm
        ]
        self.enc_blocks = nn.ModuleList(
            EncBlock(c_in, c_out, 3, 1, pad, **extra) for c_in, c_out, pad, extra in layout
        )

    def forward(self, x):
        h = x
        for stage in self.enc_blocks:
            h = stage(h)
        return h.flatten(1)

class mnist_cnn_decoder(nn.Module):
    """Transposed-conv decoder mapping a flat latent vector to a 1x28x28 image.

    Spatial size: 1 -> 3 -> 9 -> 27 -> 28, then a padded 3x3 conv keeps 28x28.

    Args:
        num_features: length of the latent vector (channel count fed to the first layer).
    """

    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(num_features, 256, 3, 1),
            nn.ReLU(),

            nn.ConvTranspose2d(256, 128, 3, 3),
            nn.ReLU(),

            nn.ConvTranspose2d(128, 64, 3, 3),
            nn.ReLU(),

            nn.ConvTranspose2d(64, 32, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 1, 3, 1, 1),
        )

    def forward(self, z):
        """Decode latents ``z`` (B, num_features) into images of shape (B, 1, 28, 28)."""
        # Use the configured width rather than a hard-coded 256, which caused a channel
        # mismatch in the first ConvTranspose2d whenever num_features != 256.
        z = z.reshape(-1, self.num_features, 1, 1)
        return self.decoder(z)



class HEPA(nn.Module):
    """Action-conditioned image predictor.

    Encodes an image, concatenates the latent with an encoded action vector,
    maps the pair through an MLP "transition" to a predicted latent, and decodes it.

    Args:
        in_features: stored (and reused by ``copy``); not otherwise used here.
        num_actions: length of the action vector accepted by ``predict``.
    """

    def __init__(self, in_features, num_actions):
        super().__init__()
        self.in_features = in_features
        self.num_actions = num_actions

        self.num_features = 256
        # self.encoder = mnist_cnn_encoder(self.num_features)
        self.encoder = Encoder28(self.num_features)

        self.action_encoder = nn.Sequential(
            nn.Linear(num_actions, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        )

        # NO BATCHNORM
        self.transition = nn.Sequential(
            nn.Linear(self.num_features + 128, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, self.num_features)
        )

        #for Mnist (-1, 1, 28, 28)
        # self.decoder = mnist_cnn_decoder(self.num_features)
        self.decoder = Decoder28(self.num_features)

    def forward(self, x):
        """Return the encoder's latent for ``x``."""
        return self.encoder(x)

    def predict(self, x, a=None):
        """Predict the decoded output for images ``x`` under action ``a``.

        Args:
            x: input image batch.
            a: (B, num_actions) action tensor; defaults to all zeros on ``x.device``.
        """
        if a is None:
            a = torch.zeros(x.shape[0], self.num_actions, device=x.device)

        z = self.encoder(x)
        a = self.action_encoder(a)
        z_pred = self.transition(torch.cat([z, a], dim=1))
        pred = self.decoder(z_pred)
        return pred

    def copy(self):
        """Return an independent HEPA with the same hyperparameters and weights, on the same device."""
        # __init__ takes only (in_features, num_actions); the previous call also passed a
        # nonexistent `self.backbone`, so copy() always raised AttributeError.
        model = HEPA(self.in_features, self.num_actions).to(next(self.parameters()).device)
        model.load_state_dict(self.state_dict())
        return model

In [4]:
from Utils.utils import get_optimiser
from Utils.functional import cosine_schedule

epochs = 250
batch_size = 256
warmup_epochs = 10  # linear lr warm-up length; cosine_schedule covers the remaining epochs
num_actions = 5     # must match the length of the action vector returned by augment()
dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

start_lr = 3e-4
end_lr = 1e-6
# Linear warm-up: linspace's leading 0 is dropped, so the first epoch uses start_lr / warmup_epochs.
warmup_lrs = torch.linspace(0, start_lr, warmup_epochs + 1)[1:]
lrs = torch.cat([warmup_lrs, cosine_schedule(start_lr, end_lr, epochs - warmup_epochs)])

wds = cosine_schedule(0.04, 0.4, epochs)
# The training loop indexes lrs[e] and wds[e] for every epoch, so both must be long enough.
assert len(lrs) >= epochs and len(wds) >= epochs, 'lr/wd schedule shorter than epochs'

# NOTE(review): second argument presumably is the action dimension (as for HEPA) — confirm for iGPA.
model = iGPA(1, num_actions).to(device)
# model = HEPA(1, num_actions).to(device)
# model = AE(1).to(device)
opt_cfg = {
    'optimiser': 'AdamW',
    'betas': (0.9, 0.999),
    'exclude_bias': True,
    'exclude_bn': True
}
optimiser = get_optimiser(model, opt_cfg)

In [5]:
import torchvision.transforms.v2.functional as F_v2

def augment(images, p, max_angle=180, max_translate=8, scale_range=0.25, max_shear=25):
    """Apply one randomly sampled affine transform to a whole batch and return it with its action.

    Each of the five components (rotation, x-shift, y-shift, scale, shear) is independently
    enabled with probability ``p``. Disabled components use their identity value. The same
    transform is applied to every image in the batch.

    Args:
        images: image batch passed to ``F_v2.affine`` (assumed (B, C, H, W) — TODO confirm).
        p: probability of enabling each component.
        max_angle: rotation drawn uniformly from [-max_angle, max_angle) degrees.
        max_translate: integer pixel shift drawn uniformly from [-max_translate, max_translate].
        scale_range: scale drawn uniformly from [1 - scale_range, 1 + scale_range).
        max_shear: shear drawn uniformly from [-max_shear, max_shear) degrees.
        All ranges must be positive; they also normalise the returned action.

    Returns:
        (images_aug, action). ``action`` is a float32 (B, 5) tensor on ``images.device``:
        [angle, tx, ty, scale - 1, shear], each divided by its range, so it lies in [-1, 1].

    Raises:
        ValueError: if any range is not positive (the action normalisation would divide by zero).
    """
    if min(max_angle, max_translate, scale_range, max_shear) <= 0:
        raise ValueError('augmentation ranges must be positive')

    # Sample Action. The RNG draws happen in the same order as before, so the defaults
    # reproduce the previous sampling exactly.
    enabled = torch.rand(5) < p  # whether to apply each augmentation
    angle = torch.rand(1).item() * 2 * max_angle - max_angle if enabled[0] else 0
    translate_x = torch.randint(-max_translate, max_translate + 1, (1,)).item() if enabled[1] else 0
    translate_y = torch.randint(-max_translate, max_translate + 1, (1,)).item() if enabled[2] else 0
    scale = torch.rand(1).item() * 2 * scale_range + (1.0 - scale_range) if enabled[3] else 1.0
    shear = torch.rand(1).item() * 2 * max_shear - max_shear if enabled[4] else 0
    images_aug = F_v2.affine(images, angle=angle, translate=(translate_x, translate_y), scale=scale, shear=shear)

    # Normalise by the same ranges used for sampling, so the action vector stays in [-1, 1].
    action = torch.tensor(
        [
            angle / max_angle,
            translate_x / max_translate,
            translate_y / max_translate,
            (scale - 1.0) / scale_range,
            shear / max_shear,
        ],
        dtype=torch.float32,
        device=images.device,
    ).unsqueeze(0).repeat(images.shape[0], 1)

    return images_aug, action

In [6]:
# Stop after this many epochs. NOTE: the lr/wd schedules span all `epochs`, so stopping early
# truncates them (the cosine decay never reaches end_lr).
stop_epoch = 50
# bf16 autocast only on CUDA. Torch disables (with a warning) a 'cuda' autocast when CUDA is
# unavailable, so CPU runs were already fp32; this keeps that without the hard-coded device.
use_amp = device.type == 'cuda'

losses = []
start_time = time.time()
model.train()
for e in range(epochs):
    # train_set.apply_transform(batch_size=batch_size)
    # Apply this epoch's lr and wd. Groups whose weight_decay is 0 (presumably the bias/BN
    # groups excluded via opt_cfg) keep zero decay.
    for param_group in optimiser.param_groups:
        param_group['lr'] = lrs[e].item()
        if param_group['weight_decay'] != 0:
            param_group['weight_decay'] = wds[e].item()

    loop = tqdm(dataloader, total=len(dataloader), leave=False)
    loop.set_description(f'Epoch {e+1}/{epochs}')
    if e > 0:
        loop.set_postfix(loss=losses[-1], time_elapsed=time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time)))

    epoch_losses = []
    for images, _ in loop:
        # The model sees the original images plus the sampled action and must predict the augmented images.
        images_aug, action = augment(images, 0.25)
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=use_amp):
            preds = model.predict(images, action)
            # NOTE(review): sum(-1) reduces only the last axis (width, for (B, C, H, W) images)
            # before averaging over the rest. Use .flatten(1).sum(-1) if a per-image sum was intended.
            loss = F.mse_loss(preds, images_aug, reduction='none').sum(-1).mean()

        loss.backward()
        optimiser.step()
        optimiser.zero_grad(set_to_none=True)
        epoch_losses.append(loss.item())
    loop.close()
    losses.append(sum(epoch_losses) / len(epoch_losses))

    if (e+1) % 10 == 0:
        print(f'Epoch {e+1}/{epochs} - Loss: {losses[-1]}')

    if (e+1) == stop_epoch:
        break

fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title='Training loss', xlabel='Epoch', ylabel='Mean batch loss')
plt.show()

                                                                                                  

Epoch 10/250 - Loss: 1.831953390481624


                                                                                                  

Epoch 20/250 - Loss: 1.1900731873639085


                                                                                                  

KeyboardInterrupt: 