In [27]:
import torch

from dataset import *
from utils import *

batch_size = 1
training_data = dataset("../ResizedTrainVal/color", "../ResizedTrainVal/label", target_transform=target_remap())
#test_data = dataset("rtest/color", "rtest/label", target_transform=target_remap())

train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, pin_memory=True)
#test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True, pin_memory=True,collate_fn=diff_size_collate)

In [None]:
import numpy as np


In [None]:
import torch.nn as nn

class EncoderPart(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(din, dout, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)   
        )
    
    def forward(self, x):
        return self.encoder(x)

class Encoder(nn.Module):
    def __init__(self, din):
        super().__init__()
        self.encoderPart1 = EncoderPart(din, 64)
        self.encoderPart2 = EncoderPart(64, 32)
        self.encoderPart3 = EncoderPart(32, 16)
    
    def forward(self, x):
        x = self.encoderPart1(x)
        x = self.encoderPart2(x)
        x = self.encoderPart3(x)
        return x

class DecoderPart(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(din, dout, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2)
        )
        
    def forward(self, x):
        return self.encoder(x)

class Decoder(nn.Module):
    def __init__(self, dout):
        super().__init__()
        self.decoderPart1 = DecoderPart(16, 16)
        self.decoderPart2 = DecoderPart(16, 32)
        self.decoderPart3 = DecoderPart(32, 64)
        self.decoderOut = nn.Sequential(
            nn.Conv2d(64, dout, kernel_size=3, padding=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        x = self.decoderPart1(x)
        x = self.decoderPart2(x)
        x = self.decoderPart3(x)
        x = self.decoderOut(x)
        return x

class Autoencoder(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.encoder = Encoder(din)
        self.decoder = Decoder(dout)
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
        

In [None]:
model = Autoencoder(3, 3)
image, label = training_data[0]
print(image)
print(image.size())
pred = model(image)
print(pred)

In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """Converts image into patch embeddings"""
    def __init__(self, img_size=512, patch_size=16, in_channels=3, embed_dim=512):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2  # Number of patches
        self.patch_size = patch_size
        self.patch_dim = in_channels * patch_size * patch_size  # Flattened patch size
        self.projection = nn.Linear(self.patch_dim, embed_dim)  # Linear projection to embedding size
    
    def forward(self, x):
        B, C, H, W = x.shape  # Batch, Channels, Height, Width
        # Convert image (B, C, H, W) into patches:  
        # → (B, num_patches_height, num_patches_width, patch_size, patch_size, C):
        x = x.reshape(B, C, H // self.patch_size, self.patch_size, W // self.patch_size, self.patch_size).permute(0, 2, 4, 3, 5, 1)
        # → (B, num_patches, patch_dim):
        x = x.reshape(B, self.num_patches, -1)  # Flatten patches
        return self.projection(x)  # Project into embedding space



#Uses Transformer Encoder Layers (self-attention + feedforward network).
#Positional Encoding ensures order information is preserved.
class TransformerEncoder(nn.Module):
    def __init__(self, img_size=512, patch_size=16, embed_dim=512, num_heads=4, ff_dim=1024, num_layers=6):
        super().__init__()
        self.encoder_layers = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim),
            num_layers=num_layers
        )
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, embed_dim))  # # Positional encoding TODO: ADJUST FIR IMG SIZE
    
    def forward(self, x):
        x = x + self.pos_embedding[:, :x.shape[1], :]    # Add position information
        return self.encoder_layers(x)



# Uses cross-attention to refine the latent representation.
class TransformerDecoder(nn.Module):    
    def __init__(self, embed_dim=512, num_heads=4, ff_dim=1024, num_layers=6):
        super().__init__()
        self.decoder_layers = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim),
            num_layers=num_layers
        )
        #self.fc_out = nn.Linear(embed_dim, patch_dim)    # Convert embeddings back to patch values
    
    def forward(self, x, memory):
        return self.decoder_layers(x, memory)  # Keep the output in (B, P, embed_dim)
        #x = self.decoder_layers(x, memory)
        #return self.fc_out(x)


class ReconstructionTransformerAutoencoder(nn.Module):
    def __init__(self, img_size=521, patch_size=16, in_channels=3, embed_dim=512, num_heads=8, ff_dim=1024, num_layers=6):  #ff_dim=2*embed_dim
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.encoder = TransformerEncoder(img_size, self.patch_embed.patch_size, embed_dim, num_heads, ff_dim, num_layers)
        self.decoder = TransformerDecoder(embed_dim, num_heads, ff_dim, num_layers)
        self.fc_out = nn.Linear(embed_dim, self.patch_embed.patch_dim)    # Convert embeddings back to patch values

    
    def forward(self, x):
        print("[Reconstruction] Forward pass")
        patches = self.patch_embed(x)
        encoded = self.encoder(patches)
        decoded = self.decoder(encoded, encoded)
        reconstructed_patches = self.fc_out(decoded)
        B, P, D = reconstructed_patches.shape  # (batch, num_patches, patch_dim)
        print(f"[Reconstruction] (B, P, D): ({B}, {P}, {D}")
        print(f"[Reconstruction] embed_dim (512?) vs patch_dim ({self.patch_embed.patch_dim}): {D}")
        img_size = int(P ** 0.5) * self.patch_embed.patch_size
        return reconstructed_patches.view(B, 3, img_size, img_size)  # Reshape back to image





In [62]:
import torch
import torch.nn as nn

class SegmentationTransformerAutoencoder(nn.Module):
    def __init__(self, img_size=512, patch_size=16, in_channels=3, embed_dim=512, num_heads=4, ff_dim=1024, num_layers=6, num_classes=4):
        super().__init__()
        
        # Patch embedding to break image into patches
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        
        # Transformer encoder for feature extraction
        self.encoder = TransformerEncoder(img_size, self.patch_embed.patch_size, embed_dim, num_heads, ff_dim, num_layers)
        
        # Transformer decoder to output class logits (for segmentation)
        self.decoder = TransformerDecoder(embed_dim, num_heads, ff_dim, num_layers)
        
        # Segmentation head to output logits for 4 classes (background, boundary, cat, dog)
        self.segmentation_head = nn.Conv2d(embed_dim, num_classes, kernel_size=1)  # Output 4 channels

    def forward(self, x):
        print("[Segmentation] Forward pass")
        # Embed the input image into patches
        patches = self.patch_embed(x)
        encoded = self.encoder(patches)
        decoded = self.decoder(encoded, encoded)  # (B, num_patches, embedding_dim)
   
        # Reshape the decoded output to match image dimensions
        B, P, D = decoded.shape
        print(f"[Segmentation] (B, P, D): ({B}, {P}, {D})") #(1, 1024, 512)
        print(f"[Segmentation] patch_dim: {self.patch_embed.patch_dim}")  #768
        img_size = int(P ** 0.5) #* self.patch_embed.patch_size  # Calculate the original image size: 
                                                                # int(P ** 0.5) = sqrt(num_patches) = num patches along each side
        print(f"[Segmentation] calculated img_size: {img_size} -> img_size == 512: {img_size==512}")    # 512 True
        
        # Reshape the decoded output into a 2D spatial representation
        decoded = decoded.view(B, D, img_size, img_size)  # Reshape to (B, H, W, D) and change to (B, D, H, W) for Conv2d

        
        # Segmentation head: output logits for 4 classes (background, boundary, cat, dog)
        segmentation_logits = self.segmentation_head(decoded)  # (B, num_classes, H, W)
        return segmentation_logits  # Output segmentation mask (logits for each class)


In [63]:
from tqdm import tqdm
import torch
import imgaug.augmenters as iaa

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
    
model_1 = ReconstructionTransformerAutoencoder().to(device)
pytorch_total_params = sum(p.numel() for p in model_1.parameters())
print(f"Total parameters in model_1 (SegmentationTransformerAutoencoder): {pytorch_total_params}")
model_2 = SegmentationTransformerAutoencoder().to(device)
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-3


losses = []
target_batch_size = 1  #TODO before submission
batch_size = 1          #TODO before submission
to_print = True
for model in [model_2, model_1]:
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    for batch, (X, _) in enumerate(tqdm(train_dataloader, total=len(train_dataloader), desc="Training")):
        # Resize images using imgaug
        #X = X.permute(0, 2, 3, 1).cpu().numpy()  # Convert from [B, C, H, W] to [B, H, W, C]
        #resize = iaa.Resize({"height": 256, "width": 256})
        #X = resize(image=X[0])  # Apply the resize operation
        
        # Convert back to PyTorch tensor and send to device
        #X = torch.tensor(X).permute(0, 3, 1, 2)  # Convert back to [B, C, H, W]
        
        X = X.to(device)
        # Compute prediction
        pred = model(X)
        
        # Compute loss
        loss = loss_fn(pred, X)
        losses.append(loss.item())
        
        loss.backward()
        
        # Ensure gradients are reset to 0 for new batch
        optimizer.step()
        optimizer.zero_grad()
           
        plt.imshow(X[0].permute(1,2,0).cpu().detach().numpy())
        plt.show()
        plt.imshow(pred[0].permute(1,2,0).cpu().detach().numpy())
        plt.show()
    print(f"losses: {losses}")


    



Total parameters in model_1 (SegmentationTransformerAutoencoder): 32855296


Training:   0%|          | 0/3673 [00:00<?, ?it/s]

[Segmentation] Forward pass
[Segmentation] (B, P, D): (1, 1024, 512)
[Segmentation] patch_dim: 768
[Segmentation] calculated img_size: 32 -> img_size == 512: False


Training:   0%|          | 0/3673 [00:00<?, ?it/s]


RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4

### Training

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    losses = []
    model.train()
    target_batch_size = 64  #TODO before submission
    batch_size = 16          #TODO before submission
    to_print = True
    for batch, (X, _) in enumerate(tqdm(dataloader, total=len(dataloader), desc="Training")):
        X = X.to(device)
        # Compute prediction
        pred = model(X)
        
        # Compute loss
        loss = loss_fn(pred, X)
        losses.append(loss.item())
        
        loss.backward()
        
        if batch % (target_batch_size/batch_size) == 0:
            # Ensure gradients are reset to 0 for new batch
            optimizer.step()
            optimizer.zero_grad()
            # if to_print:
            #     print(f"memory: {torch.cuda.device_memory_used()}")
            #     to_print = False
        
    return np.mean(losses)
        
    

### Evaluation

In [None]:
target_size = 512
interpolation = 'bilinear'

def eval(dataloader, model, loss_fn):
    model.eval()
    num_batches = len(dataloader)
    total_loss = 0.0
    with torch.no_grad():
        for batch, (X, _) in enumerate(tqdm(dataloader, total=len(dataloader), desc="Training")):
            X, meta_list = process_batch_forward(X, target_size=target_size)
            X = X.to(device)
            # Compute prediction
            pred = model(X)

            pred = process_batch_reverse(pred, meta_list, interpolation=interpolation)

            for p, label in zip(pred, X):
                
                # Move individual prediction and label to the device
                p = p.to(device).unsqueeze(0)  # Add batch dimension
                label = label.to(device).unsqueeze(0)  # Add batch dimension and ensure type is long
                
                # print(p.size(), flush=True)
                # print(label.size())
                # Calculate the loss for the current pair
                loss = loss_fn(p, label.squeeze(1))
                total_loss += loss.item()
            # Compute loss
            # loss = loss_fn(pred, X)
            # losses.append(loss.item())
    
    return total_loss / num_batches

### Run Autoencoder

In [None]:
from tqdm import tqdm
import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
    
model = Autoencoder(3, 3).to(device)
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

best_eval_loss = np.inf 

for epoch in range(100):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loss = train(train_dataloader, model, loss_fn, optimizer)
    eval_loss = eval(train_dataloader, model, loss_fn)
    print(f"Eval Loss: {eval_loss} for epoch {epoch}")
    with open("test.txt", "a") as file:
        file.write(f"Eval Loss: {eval_loss} for epoch {epoch}\n")
        
    if eval_loss <= best_eval_loss:
        best_eval_loss = eval_loss
        checkpoint = {"model": model.state_dict(),
              "optimizer": optimizer.state_dict()}
        torch.save(checkpoint, f"autoencoder/checkpoint_{epoch}.pytorch")
        img, label = training_data[0]
        img = img.to(device)

        res = model(img.unsqueeze(0))
        plt.imshow(res[0].permute(1,2,0).cpu().detach().numpy())
        plt.savefig(f"test{epoch}.png", format="png")
        plt.show()

    



### Run Transformer Autoencoder

In [None]:
from tqdm import tqdm
import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
    
model = ReconstructionTransformerAutoencoder().to(device) #using default img_size=521, patch_size=16, in_channels=3, embed_dim=512, num_heads=8, ff_dim=1024, num_layers=6
loss_fn = nn.MSELoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

best_eval_loss = np.inf 

for epoch in range(100):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loss = train(train_dataloader, model, loss_fn, optimizer)
    eval_loss = eval(train_dataloader, model, loss_fn)
    print(f"Eval Loss: {eval_loss} for epoch {epoch}")
    with open("test.txt", "a") as file:
        file.write(f"Eval Loss: {eval_loss} for epoch {epoch}\n")
        
    if eval_loss <= best_eval_loss:
        best_eval_loss = eval_loss
        checkpoint = {"model": model.state_dict(),
              "optimizer": optimizer.state_dict()}
        torch.save(checkpoint, f"autoencoder/checkpoint_{epoch}.pytorch")
        img, label = training_data[0]
        img = img.to(device)

        res = model(img.unsqueeze(0))
        plt.imshow(res[0].permute(1,2,0).cpu().detach().numpy())
        plt.savefig(f"test{epoch}.png", format="png")
        plt.show()

    



In [None]:
import matplotlib.pyplot as plt

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

checkpoint = torch.load("autoencoder/checkpoint_20.pytorch")
model.load_state_dict(checkpoint["model"])
model.to(device)

img, label = training_data[0]
img = img.to(device)

res = model(img.unsqueeze(0))
print(res.size())
plt.imshow(res[0].permute(1,2,0).cpu().detach().numpy())
plt.show