In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd drive/My \Drive/Acad/ADS/Project/

In [None]:
!pip install timm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torch.utils.data import DataLoader, TensorDataset

class SwinTransformerUNet(nn.Module):
    """Dense-prediction head on top of a timm Swin Transformer backbone.

    Despite the name, there are no U-Net skip connections: only the final
    backbone feature map is used. It is bilinearly upsampled to
    ``img_size x img_size`` and passed through a two-layer conv decoder that
    emits ``num_classes`` channels of raw logits (no final activation).

    Args:
        num_classes: Number of output channels (e.g. 1 for a binary mask).
        img_size: Side length of the square output map. The default backbone
            is a 224-input model, so inputs are presumably 224x224 -- TODO confirm.
        pretrained: Passed to ``timm.create_model`` to load pretrained weights.
        backbone: timm model name of the Swin backbone.
        in_chans: Number of input image channels given to the backbone.
    """

    def __init__(self, num_classes, img_size=224, pretrained=True,
                 backbone='swin_tiny_patch4_window7_224', in_chans=3):
        super().__init__()

        # num_classes=0 drops timm's classification head; only forward_features is used.
        self.swin_transformer = timm.create_model(backbone,
                                                  pretrained=pretrained,
                                                  num_classes=0,
                                                  in_chans=in_chans)
        # Channel width of the last stage (768 for swin_tiny). Read it from the
        # backbone instead of hard-coding so other backbones plug in unchanged.
        feat_ch = self.swin_transformer.num_features
        hidden_ch = feat_ch // 2

        self.upsample = nn.Upsample(size=(img_size, img_size), mode='bilinear', align_corners=True)

        # NOTE(review): the decoder runs at full output resolution on the
        # feat_ch-wide map, which is memory-heavy; decoding before upsampling
        # would be cheaper but changes the architecture.
        self.decoder = nn.Sequential(
            nn.Conv2d(feat_ch, hidden_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(hidden_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_ch, num_classes, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Return logits of shape (B, num_classes, img_size, img_size)."""
        feats = self.swin_transformer.forward_features(x)

        if feats.ndim == 3:
            # Older timm releases return Swin features as flattened tokens
            # (B, L, C); fold them back into a square (B, H, W, C) grid so both
            # layouts share the path below.
            batch, n_tokens, channels = feats.shape
            side = int(round(n_tokens ** 0.5))
            if side * side != n_tokens:
                raise ValueError(
                    f"cannot reshape {n_tokens} backbone tokens into a square grid")
            feats = feats.reshape(batch, side, side, channels)

        # Channels-last (B, H, W, C) -> channels-first (B, C, H, W) for Conv2d.
        feats = feats.permute(0, 3, 1, 2)

        feats = self.upsample(feats)
        return self.decoder(feats)

# Prefer the GPU when one is present, but fall back to CPU so the notebook still
# runs (slowly) on a CPU-only runtime instead of failing when moving the model.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SwinTransformerUNet(num_classes=1, img_size=224).to(device)

In [None]:
'''
Notebook to train Swin Transformer model
'''

# Import required libraries
import os
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
import torchvision
from PIL import Image
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

# Load training data, cast to float32 to match the model's parameters.
# torch.from_numpy wraps the already-converted array without a second full copy.
# Labels (outputs): presumably masks shaped like the model output (N, 1, 224, 224)
# -- TODO confirm; BCEWithLogitsLoss requires target and logits shapes to match.
x_trainHR = torch.from_numpy(np.load('./Data/trans/train_labels.npy').astype(np.float32))
# Images (conditions): presumably channels-first (N, 3, 224, 224) -- TODO confirm.
x_trainLR = torch.from_numpy(np.load('./Data/trans/train_images.npy').astype(np.float32))

# Fail early with a clear message if labels and images are not paired one-to-one.
assert len(x_trainHR) == len(x_trainLR), "label/image sample counts differ"

# Print data dimensions
print(x_trainHR.shape)
print(x_trainLR.shape)

# Create dataset and dataloader for efficient data loading and batching.
# TensorDataset yields (label, image) pairs in this order; the training loop
# unpacks them the same way.
dataset = TensorDataset(x_trainHR, x_trainLR)
# Reshuffle every epoch so batches are not always drawn in the same fixed order;
# the seeded generator (seed 42) keeps that order reproducible across runs.
dataloader = DataLoader(dataset, batch_size=5, shuffle=True,
                        generator=torch.Generator().manual_seed(42))

l = len(dataloader)  # number of batches per epoch
# The decoder ends in a plain Conv2d (no sigmoid), so its outputs are logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
epochs = 100

# Training loop: one optimizer step per mini-batch; the mean of the per-batch
# losses is printed after each epoch.
# Move batches to wherever the model actually lives (GPU or CPU), so this stays
# consistent with however the model was placed.
model_device = next(model.parameters()).device
# The checkpoint directory must exist before the first torch.save below.
os.makedirs("Weights", exist_ok=True)

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for labels, inputs in tqdm(dataloader):
        inputs, labels = inputs.to(model_device), labels.to(model_device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # max(..., 1) avoids ZeroDivisionError if the dataloader is empty.
    print(f'Epoch {epoch+1} Loss: {running_loss/max(len(dataloader), 1)}')

    # Checkpoint after every epoch, overwriting the same file. This pickles the
    # whole module (not just its state_dict), so loading it requires the
    # SwinTransformerUNet class to be defined in the loading session.
    torch.save(model, os.path.join("Weights", "ViT_ckpt_1.pt"))