In [37]:
import torch

from dataset import *


batch_size = 2
training_data = dataset("../ResizedTrainVal/color", "../ResizedTrainVal/label", target_transform=target_remap())
#test_data = dataset("Test/color", "Test/label", target_transform=target_remap())

train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, pin_memory=True)
#test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True, pin_memory=True,
                             #collate_fn=diff_size_collate)

In [38]:
import numpy as np


In [39]:
import torch.nn as nn

class EncoderPart(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(din, dout, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)   
        )
    
    def forward(self, x):
        return self.encoder(x)

class Encoder(nn.Module):
    def __init__(self, din):
        super().__init__()
        self.encoderPart1 = EncoderPart(din, 64)
        self.encoderPart2 = EncoderPart(64, 32)
        self.encoderPart3 = EncoderPart(32, 16)
    
    def forward(self, x):
        x = self.encoderPart1(x)
        x = self.encoderPart2(x)
        x = self.encoderPart3(x)
        return x

class DecoderPart(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(din, dout, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2)
        )
        
    def forward(self, x):
        return self.encoder(x)

class Decoder(nn.Module):
    def __init__(self, dout):
        super().__init__()
        self.decoderPart1 = DecoderPart(16, 16)
        self.decoderPart2 = DecoderPart(16, 32)
        self.decoderPart3 = DecoderPart(32, 64)
        self.decoderOut = nn.Conv2d(64, dout, kernel_size=3, padding=1)
        
    def forward(self, x):
        x = self.decoderPart1(x)
        x = self.decoderPart2(x)
        x = self.decoderPart3(x)
        x = self.decoderOut(x)
        return x

class Autoencoder(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.encoder = Encoder(din)
        self.decoder = Decoder(dout)
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
        

In [40]:
model = Autoencoder(3, 3)
image, label = training_data[0]
print(image)
print(image.size())
pred = model(image)
print(pred)

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
torch.Size([3, 512, 512])
tensor([[[-0.0512, -0.0411, -0.0421,  ..., -0.0427, -0.0423, -0.0391],
         [-0.0534, -0.0469, -0.0478,  ..., -0.0496, -0.0473, -0.0481],
         [-0.0534, -0.0475, -0.0489,  ..., -0.

In [41]:
def train(dataloader, model, loss_fn, optimizer):
    losses = []
    model.train()
    target_batch_size = 32  #TODO before submission
    batch_size = 4          #TODO before submission
    
    for batch, (X, _) in enumerate(tqdm(dataloader, total=len(dataloader), desc="Training")):
        X = X.to(device)
        # Compute prediction
        pred = model(X)
        # Compute loss
        loss = loss_fn(pred, X)
        losses.append(loss.item())
        
        loss.backward()
        
        if batch % (target_batch_size/batch_size) == 0:
            # Ensure gradients are reset to 0 for new batch
            optimizer.zero_grad()
            optimizer.step()
        
    return np.mean(losses)
        
    

In [42]:
def eval(dataloader, model, loss_fn):
    losses = []
    model.eval()
    target_batch_size = 32  #TODO before submission
    batch_size = 4          #TODO before submission
    with torch.no_grad():
        for batch, (X, _) in enumerate(tqdm(dataloader, total=len(dataloader), desc="Training")):
            X = X.to(device)
            # Compute prediction
            pred = model(X)
            # Compute loss
            loss = loss_fn(pred, X)
            losses.append(loss.item())
        
    return np.mean(losses)

In [43]:


if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
    
model = Autoencoder(3, 3).to(device)
loss_fn = nn.MSELoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

best_eval_loss = np.inf 

for epoch in range(100):
    print(f"---------------------------------------- Epoch {epoch}")
    train_loss = train(train_dataloader, model, loss_fn, optimizer)
    eval_loss = eval(train_dataloader, model, loss_fn)
    if epoch % 5 == 0:
        print(f"Eval Loss: {eval_loss} for epoch {epoch}")
    if eval_loss < best_eval_loss:
        best_eval_loss = eval_loss
        checkpoint = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict()
        }
        torch.save(checkpoint, f"autoencoder/checkpoint_{epoch}.pytorch")
    



---------------------------------------- Epoch 0


KeyboardInterrupt: 