Initial CNN architecture

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import os
import random
import torchvision.transforms as T

In [2]:
# --- Custom Dataset for .npy SAR images ---
class NPYContrastiveDataset(Dataset):
    """Dataset serving every ``.npy`` array found directly inside a folder.

    Each array is loaded as float32, min-max rescaled, given a leading
    channel axis and wrapped as a tensor. If a ``transform`` is supplied its
    return value is handed back as-is; otherwise the same tensor is returned
    twice as an ``(img, img)`` pair.
    """

    def __init__(self, folder_path, transform=None):
        """Collect the sorted paths of all ``.npy`` files in ``folder_path``.

        Args:
            folder_path: Directory listed (non-recursively) for ``.npy`` files.
            transform: Optional callable applied to each image tensor.
        """
        npy_paths = []
        for entry in os.listdir(folder_path):
            if entry.endswith('.npy'):
                npy_paths.append(os.path.join(folder_path, entry))
        npy_paths.sort()

        self.file_paths = npy_paths
        self.transform = transform

    def __len__(self):
        """Return how many ``.npy`` files were indexed."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load, rescale and return the sample stored at position ``idx``.

        Returns:
            ``self.transform(img)`` when a transform is set, else ``(img, img)``.
        """
        raw = np.load(self.file_paths[idx]).astype(np.float32)

        # Min-max rescale; the epsilon keeps constant-valued arrays from
        # dividing by zero.
        lo, hi = raw.min(), raw.max()
        scaled = (raw - lo) / (hi - lo + 1e-8)

        img = torch.from_numpy(scaled[np.newaxis, ...])  # [1, H, W]

        if not self.transform:
            return img, img
        return self.transform(img)

In [4]:
# initial CNN model placeholder
class simple_cnn(nn.Module):
    """Small convolutional encoder: three conv/ReLU/max-pool stages, then
    global average pooling.

    Accepts 1-channel input and returns one flattened 128-value feature
    vector per sample.
    """

    def __init__(self):
        """Build the conv/pool layers, the shared ReLU and the global pool."""
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        pool_kwargs = dict(kernel_size=2, stride=2)

        # Layers are registered in the order they are applied in forward().
        self.c1 = nn.Conv2d(1, 32, **conv_kwargs)
        self.p1 = nn.MaxPool2d(**pool_kwargs)
        self.c2 = nn.Conv2d(32, 64, **conv_kwargs)
        self.p2 = nn.MaxPool2d(**pool_kwargs)
        self.c3 = nn.Conv2d(64, 128, **conv_kwargs)
        self.p3 = nn.MaxPool2d(**pool_kwargs)
        self.act = nn.ReLU()
        self.avg = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Run ``x`` through the three stages and flatten the pooled result."""
        stages = (
            (self.c1, self.p1),
            (self.c2, self.p2),
            (self.c3, self.p3),
        )
        for conv, pool in stages:
            x = pool(self.act(conv(x)))

        pooled = self.avg(x)
        # Collapse everything after the leading dimension into one axis.
        return pooled.view(pooled.size(0), -1)
        



In [6]:
# Data augmentation for contrastive pairs
class contrastive_transforms:
    """Callable that returns two independently augmented views of one image.

    The pipeline is built with ``T.Compose`` so that every listed transform
    is actually applied. (``nn.Identity`` accepts and ignores any constructor
    arguments, so wrapping the list in it returns the input untouched and
    yields two identical views.)

    NOTE(review): assumes the input is already a float image tensor shaped
    ``[C, H, W]`` with values in ``[0, 1]`` -- confirm against the dataset
    feeding this transform. ``T.ToTensor()`` is therefore omitted: it only
    accepts PIL images / ndarrays and raises on tensors.
    """

    def __init__(self):
        """Assemble the augmentation pipeline."""
        self.transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomRotation(10),
            # Fixed 32x32 output, so every augmented view has the same size.
            T.RandomResizedCrop(32),
            # NOTE(review): on single-channel tensors current torchvision
            # treats the saturation/hue terms as no-ops -- verify for the
            # installed version.
            T.ColorJitter(0.4, 0.4, 0.4, 0.1),
            # TODO: replace (0.5, 0.5) with the custom dataset's own
            # mean/std so inputs end up with mean 0 and std 1.
            T.Normalize((0.5,), (0.5,)),
        ])

    def __call__(self, x):
        """Return ``(view_1, view_2)``; the pipeline is run twice on ``x``.

        Each call to the pipeline draws fresh random parameters, so the two
        views differ.
        """
        return self.transform(x), self.transform(x)

In [7]:
# --- Load Custom Dataset ---
# The folder can be overridden with the SAR_DATA_DIR environment variable so
# the notebook runs on other machines without editing this cell; the default
# is the original local path.
folder = os.environ.get(
    "SAR_DATA_DIR",
    r"C:\Users\Matthew.Barrett\Downloads\test_data\sigma0_arrays",
)
train_dataset = NPYContrastiveDataset(folder_path=folder, transform=contrastive_transforms())
# num_workers=0 loads batches in the main process. On Windows, DataLoader
# workers are spawned as fresh interpreters that cannot import classes
# defined inside a notebook cell, so num_workers > 0 fails or hangs here.
# drop_last=True discards a trailing incomplete batch.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0, drop_last=True)

In [8]:
# Initial SimCLR model placeholder
class SimCLR(nn.Module):
    """Encoder followed by a two-layer MLP projection head.

    Args:
        base_encoder: Module mapping an input batch to feature vectors of
            width ``feature_dim``.
        out_dim: Width of the projected embedding returned by ``forward``.
        feature_dim: Width of the encoder output (and of the head's hidden
            layer). Defaults to 128, the value that was previously
            hard-coded, so existing callers are unaffected.
    """

    def __init__(self, base_encoder, out_dim=64, feature_dim=128):
        super().__init__()
        self.encoder = base_encoder
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, out_dim)
        )

    def forward(self, x):
        """Encode ``x`` and return the projected embedding."""
        h = self.encoder(x)
        z = self.projector(h)
        return z

In [9]:
# initial loss function placeholder

# --- NT-Xent Loss ---
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """Normalised temperature-scaled cross-entropy (NT-Xent) loss.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair; every
    other row in the combined batch is treated as a negative for it.

    Both the positive term and the negatives are computed from the same
    L2-normalised embeddings, so the loss depends only on cosine similarity
    and is invariant to the scale of ``z_i`` / ``z_j``.

    Args:
        z_i: Tensor of shape ``[B, D]`` with embeddings of the first views.
        z_j: Tensor of shape ``[B, D]`` with embeddings of the second views.
        temperature: Positive scalar dividing the cosine similarities.

    Returns:
        Scalar tensor: the mean loss over all ``2 * B`` anchors.
    """
    batch_size = z_i.shape[0]

    # Normalise once, after stacking, so numerator and denominator agree.
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    logits = torch.matmul(z, z.T) / temperature

    # An anchor must never be compared with itself: setting the diagonal to
    # -inf removes it from the softmax denominator.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # The positive for anchor k is k + B (first half) or k - B (second half).
    first_half = torch.arange(batch_size, device=z.device)
    targets = torch.cat([first_half + batch_size, first_half], dim=0)

    # cross_entropy = -log(exp(pos) / sum_{k != anchor} exp(sim_k)), computed
    # via log-sum-exp so large similarities cannot overflow.
    return F.cross_entropy(logits, targets)

In [10]:
# --- Training Setup ---
# Seed the Python, NumPy and PyTorch RNGs before the model is built so that
# weight initialisation (and any later draws from these generators) repeats
# across Restart Kernel -> Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimCLR(simple_cnn()).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# model training
epochs = 1

# Fail early with a clear message instead of a ZeroDivisionError at the end
# of the epoch: the loader yields no batches when the dataset holds fewer
# samples than one batch and incomplete batches are dropped.
num_batches = len(train_loader)
if num_batches == 0:
    raise RuntimeError(
        "train_loader produced no batches; check that the data folder "
        "contains at least one full batch of .npy files."
    )

model.train()
for epoch in range(epochs):
    total_loss = 0.0
    for batch_idx, (x_i, x_j) in enumerate(train_loader):
        # Each batch is a pair of views of the same samples.
        x_i, x_j = x_i.to(device), x_j.to(device)
        z_i = model(x_i)
        z_j = model(x_j)

        loss = nt_xent_loss(z_i, z_j)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        print(f"  Batch {batch_idx+1}/{num_batches} completed")

    avg_loss = total_loss / num_batches
    print(f"Epoch [{epoch+1}/{epochs}] Loss: {avg_loss:.4f}")