In [None]:
import torch

# Device configuration: use the GPU when one is available.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNG (all devices) so weight initialisation and DataLoader
# shuffling are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Hyper-parameters
EPOCHS = 10
BATCH_SIZE = 64
LR = 0.001

In [None]:
import os
import zipfile
import opendatasets as od

# Fetch the Food-11 dataset from Kaggle; later cells read it from ./food11.
FOOD11_URL = 'https://www.kaggle.com/datasets/vermaavi/food11'
od.download(FOOD11_URL)

In [None]:
import os
from torch.utils.data import Dataset
from PIL import Image, ImageDraw, ImageFont

# Custom dataset (a torch.utils.data.Dataset) yielding clean / watermarked image pairs.
class Food11Dataset(Dataset):
    """Food-11 images paired with a watermarked copy, for training a watermark remover.

    Each item is ``(image, noisy, label)``:
      * ``image`` is the clean picture, converted to RGB;
      * ``noisy`` is the same picture with a date string drawn at (10, 10);
      * ``label`` is the integer prefix of the file name (``"3_120.jpg"`` -> ``3``).

    Args:
        dir: Directory containing the ``.jpg`` files.
        limit: If given, keep only the first ``limit`` files in sorted order.
        transform: Optional callable applied to both the clean and the watermarked
            PIL image. It is called separately on each, so a random augmentation
            would produce mismatched pairs.
    """

    def __init__(self, dir, limit=None, transform=None):
        self.dir = dir
        self.transform = transform

        # os.listdir order is arbitrary; sort so that `limit` selects a reproducible subset.
        files = sorted(name for name in os.listdir(dir) if name.endswith('.jpg'))
        self.file_list = files if limit is None else files[:limit]

        # Watermark font, loaded on first use and reused for every image.
        self._font = None

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_name = self.file_list[idx]
        path = os.path.join(self.dir, file_name)

        # Decode inside `with` so the file handle is released, and force RGB:
        # the model takes 3 input channels and the white fill below is an RGB tuple.
        with Image.open(path) as img:
            image = img.convert('RGB')

        noisy = self.add_watermark(image)

        if self.transform:
            image = self.transform(image)
            noisy = self.transform(noisy)

        # File names look like "<label>_<index>.jpg".
        label = int(file_name.split('_')[0])

        return image, noisy, label

    def _get_font(self):
        """Return the cached watermark font, falling back to Pillow's default font."""
        if self._font is None:
            try:
                self._font = ImageFont.truetype('arial.ttf', 30)
            except OSError:
                # arial.ttf is not present on every system (e.g. most Linux installs);
                # a smaller default-font watermark is better than crashing the loader.
                self._font = ImageFont.load_default()
        return self._font

    def add_watermark(self, image, text='01.11.2023'):
        """Return a copy of `image` with `text` drawn in white at (10, 10).

        Args:
            image: PIL image to watermark; it is not modified.
            text: Watermark string (defaults to the original date stamp).

        Returns:
            A new PIL image with the watermark applied.
        """
        noisy = image.copy()
        draw = ImageDraw.Draw(noisy)
        draw.text((10, 10), text, font=self._get_font(), fill=(255, 255, 255))
        return noisy

In [None]:
import torch.nn as nn

class Autoencoder(nn.Module):
    """Small convolutional autoencoder for single-channel 28x28 images.

    Shapes (N = batch): (N, 1, 28, 28) -> encoder -> (N, 64, 1, 1)
    -> decoder -> (N, 1, 28, 28). The decoder ends in a sigmoid, so outputs lie in [0, 1].

    NOTE(review): the next cell redefines ``Autoencoder`` for 3x256x256 inputs,
    so this version is shadowed once that cell has run.
    """

    def __init__(self):
        super().__init__()

        down = [
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),   # -> (N, 16, 14, 14)
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # -> (N, 32, 7, 7)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=7),                        # -> (N, 64, 1, 1)
        ]
        up = [
            nn.ConvTranspose2d(64, 32, kernel_size=7),  # -> (N, 32, 7, 7)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> (N, 16, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),   # -> (N, 1, 28, 28)
            nn.Sigmoid(),
        ]
        self.encoder = nn.Sequential(*down)
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)

In [None]:
import torch.nn as nn

class Autoencoder(nn.Module):
    """Convolutional autoencoder for 3-channel 256x256 images.

    The encoder applies six stride-2 convolutions, taking (N, 3, 256, 256) down to a
    (N, 512, 4, 4) bottleneck. The decoder mirrors it with six stride-2 transposed
    convolutions back to (N, 3, 256, 256). A final sigmoid keeps the reconstruction in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Channel widths from the image to the bottleneck; the decoder walks them backwards.
        widths = (3, 16, 32, 64, 128, 256, 512)
        down_pairs = list(zip(widths[:-1], widths[1:]))
        up_pairs = [(c_out, c_in) for c_in, c_out in reversed(down_pairs)]

        encoder_layers = []
        for step, (c_in, c_out) in enumerate(down_pairs, start=1):
            # Each stride-2 conv halves H and W.
            encoder_layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            # The bottleneck (last conv) has no activation.
            if step < len(down_pairs):
                encoder_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = []
        for step, (c_in, c_out) in enumerate(up_pairs, start=1):
            # output_padding=1 makes each transposed conv exactly double H and W.
            decoder_layers.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1)
            )
            decoder_layers.append(nn.Sigmoid() if step == len(up_pairs) else nn.ReLU())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)


In [None]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Resize to the 256x256 input size the autoencoder is built for, then convert to [0, 1] tensors.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Training is capped at 1000 files via `limit`; drop the argument for a full run.
training_data = Food11Dataset('food11/training', transform=transform, limit=1000)
validate_data = Food11Dataset('food11/validation', transform=transform)

train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
# Validation needs no shuffling; a fixed order keeps evaluation reproducible.
validate_dataloader = DataLoader(validate_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
import matplotlib.pyplot as plt

# Plot up to 10 samples from one training batch: clean originals on the top row,
# watermarked inputs on the bottom row.
originals, noisy_images, _ = next(iter(train_dataloader))

fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25, 4))
for i in range(min(10, len(originals))):
    # ToTensor yields (C, H, W); imshow needs (H, W, C) to render the colour image
    # (indexing channel 0 would show only the red channel through a colormap).
    axes[0][i].imshow(originals[i].permute(1, 2, 0).numpy())
    axes[1][i].imshow(noisy_images[i].permute(1, 2, 0).numpy())
axes[0][0].set_ylabel('original')
axes[1][0].set_ylabel('watermarked')
plt.show()

In [None]:
from tqdm import tqdm

model = Autoencoder()
model.to(DEVICE)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Explicit train mode (the default for a fresh module, but makes intent clear).
model.train()

for epoch in range(EPOCHS):
    running_loss = 0.0
    num_batches = 0

    # Labels are unused: the training target is the clean image itself.
    for original, noisy, _ in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{EPOCHS}'):
        noisy = noisy.to(DEVICE)
        original = original.to(DEVICE)

        output = model(noisy)
        loss = criterion(output, original)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError('train_dataloader yielded no batches; check the dataset directory.')
    # Report the mean loss over the epoch rather than only the last batch's loss.
    epoch_loss = running_loss / num_batches

    # Append the epoch loss to loss.txt (the file accumulates across runs).
    with open('loss.txt', 'a') as f:
        f.write(f'{epoch_loss}\n')

    # Overwrite latest.pth with the current weights every epoch.
    torch.save(model.state_dict(), 'latest.pth')

    print(f'Epoch [{epoch + 1}/{EPOCHS}], Loss: {epoch_loss:.4f}')

In [None]:
import torch
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Testing: reload the last checkpoint and show clean, watermarked and reconstructed
# validation images side by side.

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

validate_data = Food11Dataset('food11/validation', transform=transform)

# Fixed order so the same validation images are shown on every run.
validate_dataloader = DataLoader(validate_data, batch_size=BATCH_SIZE, shuffle=False)

model = Autoencoder()
model.to(DEVICE)
# map_location lets a checkpoint written on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load('latest.pth', map_location=DEVICE))
model.eval()

originals, noisy_images, _ = next(iter(validate_dataloader))
n_show = min(5, len(originals))

# Actually run the model on the watermarked inputs we are about to display.
with torch.no_grad():
    reconstructed = model(noisy_images[:n_show].to(DEVICE)).cpu()

rows = (('original', originals), ('watermarked', noisy_images), ('reconstructed', reconstructed))
fig, axes = plt.subplots(nrows=len(rows), ncols=5, sharex=True, sharey=True, figsize=(10, 6))
for row, (title, images) in enumerate(rows):
    axes[row][0].set_ylabel(title)
    for i in range(n_show):
        # (C, H, W) -> (H, W, C) so imshow renders the colour image.
        axes[row][i].imshow(images[i].permute(1, 2, 0).numpy())
plt.show()