In [1]:
import os
from dataset.interface import SITE_DATASET_DIR

site_options = os.listdir(SITE_DATASET_DIR)

In [2]:
from dataset.interface import Site

site = Site(site_options[0])
groups = list(site)
images = {group.name: list(group) for group in groups}

In [3]:
from dataset.pytorch import ImageDataset, collate_largest
images_dataset = ImageDataset(images)
len(images_dataset)

4647

In [4]:
from torch.nn import Module, Conv2d, ConvTranspose2d, ReLU, MaxPool2d
class Encoder(Module):
    def __init__(self,):
        super().__init__()
        self.conv1 = Conv2d(1,2,5,)
        self.conv2 = Conv2d(2,4,5,) 
        self.activation1 = ReLU()
        self.max_pool1 = MaxPool2d(2)
        self.conv3 = Conv2d(4,8,5,)
        self.conv4 = Conv2d(8,16,5,)
        self.activation2 = ReLU()
        self.max_pool2 = MaxPool2d(2)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.activation1(x)
        x = self.max_pool1(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.activation2(x)
        x = self.max_pool2(x)
        return x
    
class Decoder(Module):
    def __init__(self,):
        super().__init__()
        self.deconv1 = ConvTranspose2d(16,8,4,stride=1)
        self.deconv2 = ConvTranspose2d(8,4,4,stride=2)
        self.activation1 = ReLU()
        self.deconv3 = ConvTranspose2d(4,2,4,stride=1)
        self.deconv4 = ConvTranspose2d(2,1,4,stride=2)
        self.activation2 = ReLU()
    
    def forward(self, x):
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.activation1(x)
        x = self.deconv3(x)
        x = self.deconv4(x)
        x = self.activation2(x)
        return x

class Autoencoder(Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
    def forward(self, x):
        return self.decoder(self.encoder(x))

In [5]:
from torch.nn import MSELoss
from torch.utils.data import DataLoader, random_split
from torch.optim import Adam
import torch

split = [int(len(images_dataset)*f) for f in [0.5,0.1,0.4]]
split[0] += len(images_dataset) - sum(split)
train_images, valid_images, test_images = random_split(images_dataset, split)
    
batch_size = 16
autoencoder = Autoencoder(Encoder(),Decoder())
criterion = MSELoss()
optimizer = Adam(autoencoder.parameters(),lr=.001,eps=1e-4)
dataloader = DataLoader(train_images, batch_size=batch_size, shuffle=True, collate_fn=collate_largest(4,4), num_workers=4, prefetch_factor=2, pin_memory=True)
valid_dataloader = DataLoader(valid_images, batch_size=batch_size, collate_fn=collate_largest(4,4), num_workers=4, prefetch_factor=2, pin_memory=True)
epochs=1


In [6]:
import os
if os.path.isfile("last.model"):
    autoencoder.load_state_dict(torch.load("last.model"))
autoencoder = autoencoder.to(memory_format=torch.channels_last).to("cuda")

In [7]:
from ipywidgets import Output, Label
import time

accumulation_steps = 4
for epoch in range(epochs):
    train_loss = 0.0
    start_epoch = time.time()
    for i, batch in enumerate(dataloader):
        batch = batch.to(memory_format=torch.channels_last).to("cuda")
        out = autoencoder(batch)
        loss = criterion(out, batch) /accumulation_steps
        loss.backward()
        
        if (i+1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            train_loss += loss.item()
    print("Loss:", train_loss, "Epoch Time", time.time() - start_epoch)

Loss: 157754.97888183594 Epoch Time 857.5731198787689
