In [2]:
import torch

from dataset import *
from utils import *

batch_size = 16
training_data = dataset("rtrain/color", "rtrain/label", target_transform=target_remap())
test_data = dataset("rtest/color", "rtest/label", target_transform=target_remap())

train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, pin_memory=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True, pin_memory=True,collate_fn=diff_size_collate)

In [3]:
import numpy as np


In [4]:
import torch.nn as nn

class EncoderPart(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(din, dout, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)   
        )
    
    def forward(self, x):
        return self.encoder(x)

class Encoder(nn.Module):
    def __init__(self, din):
        super().__init__()
        self.encoderPart1 = EncoderPart(din, 64)
        self.encoderPart2 = EncoderPart(64, 32)
        self.encoderPart3 = EncoderPart(32, 16)
    
    def forward(self, x):
        x = self.encoderPart1(x)
        x = self.encoderPart2(x)
        x = self.encoderPart3(x)
        return x

class DecoderPart(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(din, dout, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2)
        )
        
    def forward(self, x):
        return self.encoder(x)

class Decoder(nn.Module):
    def __init__(self, dout):
        super().__init__()
        self.decoderPart1 = DecoderPart(16, 16)
        self.decoderPart2 = DecoderPart(16, 32)
        self.decoderPart3 = DecoderPart(32, 64)
        self.decoderOut = nn.Sequential(
            nn.Conv2d(64, dout, kernel_size=3, padding=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        x = self.decoderPart1(x)
        x = self.decoderPart2(x)
        x = self.decoderPart3(x)
        x = self.decoderOut(x)
        return x

class Autoencoder(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.encoder = Encoder(din)
        self.decoder = Decoder(dout)
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
        

In [5]:
model = Autoencoder(3, 3)
image, label = training_data[0]
print(image)
print(image.size())
pred = model(image)
print(pred)

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
torch.Size([3, 512, 512])
tensor([[[0.5083, 0.5082, 0.5080,  ..., 0.5082, 0.5081, 0.5072],
         [0.5076, 0.5090, 0.5086,  ..., 0.5090, 0.5094, 0.5091],
         [0.5074, 0.5088, 0.5083,  ..., 0.5083, 0.5090, 0.

In [6]:
def train(dataloader, model, loss_fn, optimizer):
    losses = []
    model.train()
    target_batch_size = 64  #TODO before submission
    batch_size = 16          #TODO before submission
    to_print = True
    for batch, (X, _) in enumerate(tqdm(dataloader, total=len(dataloader), desc="Training")):
        X = X.to(device)
        # Compute prediction
        pred = model(X)
        
        # Compute loss
        loss = loss_fn(pred, X)
        losses.append(loss.item())
        
        loss.backward()
        
        if batch % (target_batch_size/batch_size) == 0:
            # Ensure gradients are reset to 0 for new batch
            optimizer.step()
            optimizer.zero_grad()
            # if to_print:
            #     print(f"memory: {torch.cuda.device_memory_used()}")
            #     to_print = False
        
    return np.mean(losses)
        
    

In [7]:
target_size = 512
interpolation = 'bilinear'
def eval(dataloader, model, loss_fn):
    model.eval()
    num_batches = len(dataloader)
    total_loss = 0.0
    with torch.no_grad():
        for batch, (X, _) in enumerate(tqdm(dataloader, total=len(dataloader), desc="Training")):
            X, meta_list = process_batch_forward(X, target_size=target_size)
            X = X.to(device)
            # Compute prediction
            pred = model(X)

            pred = process_batch_reverse(pred, meta_list, interpolation=interpolation)

            for p, label in zip(pred, X):
                
                # Move individual prediction and label to the device
                p = p.to(device).unsqueeze(0)  # Add batch dimension
                label = label.to(device).unsqueeze(0)  # Add batch dimension and ensure type is long
                
                # print(p.size(), flush=True)
                # print(label.size())
                # Calculate the loss for the current pair
                loss = loss_fn(p, label.squeeze(1))
                total_loss += loss.item()
            # Compute loss
            # loss = loss_fn(pred, X)
            # losses.append(loss.item())
    
    return total_loss / num_batches

In [None]:
from tqdm import tqdm
import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
    
model = Autoencoder(3, 3).to(device)
loss_fn = nn.MSELoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

best_eval_loss = np.inf 

for epoch in range(100):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loss = train(train_dataloader, model, loss_fn, optimizer)
    eval_loss = eval(train_dataloader, model, loss_fn)
    print(f"Eval Loss: {eval_loss} for epoch {epoch}")
    with open("test.txt", "a") as file:
        file.write(f"Eval Loss: {eval_loss} for epoch {epoch}\n")
        
    if eval_loss <= best_eval_loss:
        best_eval_loss = eval_loss
        checkpoint = {"model": model.state_dict(),
              "optimizer": optimizer.state_dict()}
        torch.save(checkpoint, f"autoencoder/checkpoint_{epoch}.pytorch")
        img, label = training_data[0]
        img = img.to(device)

        res = model(img.unsqueeze(0))
        plt.imshow(res[0].permute(1,2,0).cpu().detach().numpy())
        plt.savefig(f"test{epoch}.png", format="png")
        plt.show()

    



In [None]:
import matplotlib.pyplot as plt

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

checkpoint = torch.load("autoencoder/checkpoint_20.pytorch")
model.load_state_dict(checkpoint["model"])
model.to(device)

img, label = training_data[0]
img = img.to(device)

res = model(img.unsqueeze(0))
print(res.size())
plt.imshow(res[0].permute(1,2,0).cpu().detach().numpy())
plt.show

RuntimeError: Error(s) in loading state_dict for Autoencoder:
	Missing key(s) in state_dict: "decoder.decoderOut.0.weight", "decoder.decoderOut.0.bias". 
	Unexpected key(s) in state_dict: "decoder.decoderOut.weight", "decoder.decoderOut.bias". 