In [None]:
import argparse
import json
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torcheval.metrics as metrics
import torchvision.transforms.functional as F
import tqdm
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.io import read_image

In [None]:
# Run configuration: checkpoint file, dataset directories, device and
# training hyper-parameters.
# NOTE(review): the dataset directories are absolute, machine-specific paths;
# adjust them before running elsewhere.
args = {
    'modelfile':'autoencoder-1.pth',
    'trainingdir':'/home/guserbto@GU.GU.SE/wikiart/train',
    'validationdir': '/home/guserbto@GU.GU.SE/wikiart/valid',
    'testingdir': '/home/guserbto@GU.GU.SE/wikiart/test',
    'device': 'cuda:3',
    'epochs':20,
    'batch_size':32,
}

modelfile = args['modelfile']
trainingdir = args['trainingdir']
validationdir = args['validationdir']
testingdir = args['testingdir']
epochs = args['epochs']
batch_size = args['batch_size']

# Use the requested device when it exists; otherwise fall back to the CPU so
# the notebook still runs on machines with fewer (or no) GPUs.
_requested_device = torch.device(args['device'])
if (_requested_device.type == 'cuda'
        and (_requested_device.index or 0) >= torch.cuda.device_count()):
    print("Requested device {} is not available; falling back to cpu".format(args['device']))
    device = 'cpu'
else:
    device = args['device']

# Seed every random number generator the notebook touches (Python, NumPy,
# PyTorch CPU and CUDA). The CUDA seeding calls are no-ops without CUDA.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
class WikiArtImage:
    """Lazily loaded handle to one image stored at ``imgdir/label/filename``.

    Nothing is read from disk at construction time; the pixel data is read on
    the first call to :meth:`get` and cached on the instance afterwards.
    """

    def __init__(self, imgdir, label, filename):
        self.imgdir, self.label, self.filename = imgdir, label, filename
        # Cache slot and flag for the lazily loaded tensor.
        self.image = None
        self.loaded = False

    def get(self):
        """Return the image as a float tensor divided by 255, reading it at most once."""
        if self.loaded:
            return self.image

        path = os.path.join(self.imgdir, self.label, self.filename)
        self.image = read_image(path).float()
        self.image /= 255.0
        self.loaded = True
        return self.image

In [None]:
class WikiArtDataset(Dataset):
    """Dataset of images discovered by walking ``imgdir``.

    Every file found is registered under its bare filename; its label is the
    name of the directory that contains it. Images are loaded lazily through
    ``WikiArtImage`` and moved to ``device`` when fetched.

    Attributes:
        filedict: maps filename -> ``WikiArtImage``.
        indices: list of filenames; position ``i`` backs ``dataset[i]``.
        classes: sorted list of label names; a label's integer id is its
            position in this list.
        imgdir: root directory that was walked.
        device: device that fetched images are moved to.
    """

    def __init__(self, imgdir, device="cpu"):
        filedict = {}
        indices = []
        classes = set()
        print("Gathering files for {}".format(imgdir))
        for dirpath, dirnames, filenames in os.walk(imgdir):
            # Sorting in place makes os.walk visit sub-directories in a
            # deterministic order, independent of the filesystem.
            dirnames.sort()
            sys.stdout.write('.')
            arttype = os.path.basename(dirpath)
            for art in sorted(filenames):
                # A filename seen in several directories gets a single index
                # slot, keeping len(indices) == len(filedict); the entry in
                # filedict is overwritten, so the last directory visited wins.
                if art not in filedict:
                    indices.append(art)
                filedict[art] = WikiArtImage(imgdir, arttype, art)
                classes.add(arttype)
        print("...finished")
        self.filedict = filedict
        self.imgdir = imgdir
        self.indices = indices
        # Sorted so that label ids are stable across runs and across
        # dataset instances built from directories with the same class names.
        self.classes = sorted(classes)
        self.device = device

    def __len__(self):
        return len(self.filedict)

    def _fetch(self, imgobj):
        """Return ``(image_on_device, integer_label)`` for one image object."""
        ilabel = self.classes.index(imgobj.label)
        image = imgobj.get().to(self.device)
        return image, ilabel

    def __getitem__(self, idx):
        """Return ``(image, integer_label)`` for the image at position ``idx``."""
        return self._fetch(self.filedict[self.indices[idx]])

    def get_filename(self, idx):
        """Return the filename at position ``idx`` without loading the image."""
        return self.indices[idx]

    def get_image_by_filename(self, filename):
        """Return ``(image, integer_label)`` for ``filename``.

        Raises:
            ValueError: if ``filename`` is not part of the dataset.
        """
        if filename not in self.filedict:
            raise ValueError(f"Filename '{filename}' not found in the dataset.")

        return self._fetch(self.filedict[filename])

In [None]:
class AutoEncoder(nn.Module):
    """Convolutional autoencoder for 3-channel images.

    The encoder is five stride-2 convolutions (3 -> 32 -> 64 -> 128 -> 256 -> 512
    channels), each followed by ReLU. The decoder mirrors it with transposed
    convolutions and ends in a Sigmoid. ``num_classes`` is accepted but not
    used by the network.
    """

    def __init__(self, num_classes=27):
        super().__init__()

        widths = (3, 32, 64, 128, 256, 512)
        conv_opts = dict(kernel_size=4, stride=2, padding=1)

        # Encoder: Conv2d + ReLU per stage, in increasing channel width.
        down_layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            down_layers.append(nn.Conv2d(n_in, n_out, **conv_opts))
            down_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*down_layers)

        # Decoder: mirrored ConvTranspose2d stages; the final stage (back to
        # 3 channels) is followed by Sigmoid instead of ReLU.
        mirrored = widths[::-1]
        up_layers = []
        for n_in, n_out in zip(mirrored, mirrored[1:]):
            up_layers.append(nn.ConvTranspose2d(n_in, n_out, **conv_opts))
            up_layers.append(nn.Sigmoid() if n_out == widths[0] else nn.ReLU())
        self.decoder = nn.Sequential(*up_layers)

    def encode(self, image):
        """Run only the encoder."""
        return self.encoder(image)

    def decode(self, image):
        """Run only the decoder."""
        return self.decoder(image)

    def forward(self, image):
        """Encode then decode ``image`` and return the reconstruction."""
        return self.decoder(self.encoder(image))
  

In [None]:
# TRAIN

traindataset = WikiArtDataset(trainingdir, device)

def train(epochs=3, batch_size=32, modelfile=None, device="cpu"):
    """Train a fresh AutoEncoder to reconstruct the images in ``traindataset``.

    Args:
        epochs: number of passes over the training set.
        batch_size: number of images per optimisation step.
        modelfile: if truthy, path the trained ``state_dict`` is saved to.
        device: device the model and batches are placed on.

    Returns:
        The trained model.
    """
    train_loader = DataLoader(traindataset, batch_size=batch_size, shuffle=True)

    model = AutoEncoder().to(device)
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss().to(device)

    model.train()
    for epoch in range(epochs):
        # Epochs are reported 1-based in both messages.
        print("Starting epoch {}".format(epoch + 1))
        accumulate_loss = 0.0
        for X, _ in tqdm.tqdm(train_loader):
            X = X.to(device)
            optimizer.zero_grad()
            output = model(X)
            # The input is its own reconstruction target.
            loss = criterion(output, X)
            loss.backward()
            optimizer.step()
            # .item() stores a plain float, so the running sum does not keep
            # autograd-tracked tensors alive across the epoch.
            accumulate_loss += loss.item()

        # Sum of the per-batch mean losses over the epoch.
        print("In epoch {}, epoch_loss = {}".format(epoch + 1, accumulate_loss))

    if modelfile:
        torch.save(model.state_dict(), modelfile)

    return model

In [None]:
# Train with the configured hyper-parameters and save the weights to the configured file.
model = train(
    epochs=args["epochs"],
    batch_size=args["batch_size"],
    modelfile=args["modelfile"],
    device=device,
)

In [None]:
# TEST

testingdataset = WikiArtDataset(testingdir, device)

def test(modelfile=None, device="cpu"):
    """Evaluate a saved AutoEncoder on ``testingdataset``.

    Loads the weights from ``modelfile``, reconstructs every test image one
    at a time, and prints the mean per-image MSE.

    Args:
        modelfile: path of the saved ``state_dict``; it is passed straight to
            ``torch.load``, so a real path must be supplied.
        device: device the model and images are placed on.

    Returns:
        ``(encoded_images, labels)``: two lists with one entry per test
        image, holding the encoder output and the label batch respectively.
    """
    loader = DataLoader(testingdataset, batch_size=1)

    model = AutoEncoder()
    model.load_state_dict(torch.load(modelfile, map_location=device, weights_only=True))
    model = model.to(device)
    model.eval()

    criterion = nn.MSELoss()
    total_loss = 0.0
    num_batches = 0

    encoded_images = []
    labels = []
    # Pure inference: nothing in this loop needs gradients.
    with torch.no_grad():
        for X, y in tqdm.tqdm(loader):
            X = X.to(device)

            encoded_output = model.encode(X)
            decoded_output = model.decode(encoded_output)
            encoded_images.append(encoded_output)
            labels.append(y)

            total_loss += criterion(decoded_output, X).item()
            num_batches += 1

    if num_batches == 0:
        # Avoid dividing by zero when the test directory holds no images.
        print("No test images found; reconstruction error is undefined")
        return encoded_images, labels

    avg_loss = total_loss / num_batches
    print("Average MSE (Reconstruction Error): ", avg_loss)

    return encoded_images, labels

In [None]:
# Encode the whole test set with the saved model; keeps the codes and labels for clustering.
encoded_images, labels = test(modelfile=modelfile, device=device)

In [None]:
# Evaluate model on a given filename

# Pick one test image by position, then fetch it again through its filename.
testdataset = WikiArtDataset(testingdir, device)
filename = testdataset.get_filename(5)
img, label = testdataset.get_image_by_filename(str(filename))

# Restore the trained weights and reconstruct the image without gradients.
model = AutoEncoder().to(device)
model.load_state_dict(torch.load(modelfile, map_location=device, weights_only=True))
model.eval()
with torch.no_grad():
    output = model(img.to(device))

# Show the original first, then its reconstruction.
input_img = F.to_pil_image(img)
output_img = F.to_pil_image(output)
for caption, picture in (("Input image", input_img), ("Reconstructed image", output_img)):
    plt.imshow(picture)
    print(caption)
    plt.show()

In [None]:
# Preprocess encoded images and labels for clustering

def preprocess(encoded_images, labels):
    """Turn per-batch encoder outputs and labels into NumPy arrays.

    Args:
        encoded_images: sequence of 4-D tensors, concatenated along dim 0.
        labels: sequence of tensors, concatenated along dim 0.

    Returns:
        ``(features, targets)``: ``features`` is a 2-D array with one
        flattened code per row; ``targets`` is the concatenated labels.
    """
    stacked_codes = torch.cat(encoded_images)
    stacked_labels = torch.cat(labels)

    # Unpacking four dimensions also checks that the codes are 4-D.
    n_samples, _, _, _ = stacked_codes.shape
    flat_codes = stacked_codes.view(n_samples, -1)

    features = flat_codes.cpu().numpy()
    targets = stacked_labels.cpu().numpy()

    print("Done preprocessing")

    return features, targets


In [None]:
# Flatten the collected codes and labels into NumPy arrays for scikit-learn.
encoded_images_numpy, labels_numpy = preprocess(
    encoded_images=encoded_images,
    labels=labels,
)

In [None]:
# Cluster the flattened codes into 27 groups (presumably one per art style,
# matching AutoEncoder's num_classes default — TODO confirm).
# random_state pins the centroid initialisation so re-running this cell
# yields the same clusters instead of depending on the global NumPy RNG state.
kmeans = KMeans(n_clusters=27, random_state=seed)
clusters = kmeans.fit_predict(encoded_images_numpy)

print("cluster labels:", clusters)
print("labels shape:", clusters.shape)

In [None]:
# Project the flattened codes onto their first two principal components for plotting.
# random_state keeps the projection reproducible when scikit-learn selects a
# randomized SVD solver.
pca = PCA(n_components=2, random_state=seed)
encoded_images_2d = pca.fit_transform(encoded_images_numpy)

print(encoded_images_2d)

In [None]:
# Side-by-side PCA scatter plots of the encoded images:
# left coloured by the dataset labels, right coloured by the KMeans clusters.
fig, axes = plt.subplots(1, 2, figsize=(20, 7))

panels = (
    (labels_numpy, 'PCA of Encoded Images with Actual Labels'),
    (clusters, 'PCA of Encoded Images with Cluster Labels'),
)

scatters = []
for ax, (point_colors, panel_title) in zip(axes, panels):
    points = ax.scatter(
        encoded_images_2d[:, 0],
        encoded_images_2d[:, 1],
        c=point_colors,
        alpha=0.5,
    )
    ax.set_title(panel_title)
    fig.colorbar(points, ax=ax)
    scatters.append(points)

# Keep the original handle names available to later cells.
scatter1, scatter2 = scatters

plt.show()