In [None]:
import sys
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
from torch.optim import Adam
import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
import torchvision.transforms.functional as F
from torch.optim import Adam
import torcheval.metrics as metrics
import json
import argparse

In [None]:
def resolve_device(preferred):
    """Return ``preferred`` when this machine can provide it, otherwise a fallback.

    Non-CUDA device strings are returned unchanged. A CUDA device falls back to
    'cpu' when CUDA is unavailable, and to 'cuda:0' when the requested device
    index is not present on this machine.
    """
    requested = torch.device(preferred)
    if requested.type != 'cuda':
        return preferred
    if not torch.cuda.is_available():
        return 'cpu'
    if requested.index is not None and requested.index >= torch.cuda.device_count():
        return 'cuda:0'
    return preferred


# Run configuration. The data directories and the device keep their original
# defaults, but can be overridden with the WIKIART_TRAIN_DIR, WIKIART_TEST_DIR
# and WIKIART_DEVICE environment variables so the notebook is not tied to one
# user's home directory or to a machine with at least four GPUs.
_requested_device = os.environ.get('WIKIART_DEVICE', 'cuda:3')

args = {
    'modelfile': 'wikiart.pth',
    'trainingdir': os.environ.get('WIKIART_TRAIN_DIR', '/home/guserbto@GU.GU.SE/wikiart/train'),
    'testingdir': os.environ.get('WIKIART_TEST_DIR', '/home/guserbto@GU.GU.SE/wikiart/test'),
    'device': resolve_device(_requested_device),
    'epochs': 10,
    'batch_size': 32,
}

if args['device'] != _requested_device:
    print("Device {} is not available, using {} instead".format(_requested_device, args['device']))

# Unpacked copies used by the cells below.
modelfile = args['modelfile']
trainingdir = args['trainingdir']
testingdir = args['testingdir']
device = args['device']
epochs = args['epochs']
batch_size = args['batch_size']

In [None]:
class WikiArtImage:
    """One image file located at ``imgdir/label/filename``, decoded on demand."""

    def __init__(self, imgdir, label, filename):
        self.imgdir, self.label, self.filename = imgdir, label, filename
        # Pixels are only read on the first call to get() and kept afterwards.
        self.image = None
        self.loaded = False

    def get(self):
        """Return the image as a float tensor, reading it from disk on first use."""
        if self.loaded:
            return self.image

        path = os.path.join(self.imgdir, self.label, self.filename)
        self.image = read_image(path).float()
        self.loaded = True
        return self.image

In [None]:
class WikiArtDataset(Dataset):
    """Dataset of images found by walking ``imgdir``; the containing folder is the label.

    Attributes:
        filedict: maps each image's path relative to ``imgdir`` to its WikiArtImage.
        indices: the same relative paths in dataset order (position -> key).
        classes: alphabetically sorted class names; a label's integer id is its
            position in this list, so the mapping is identical for every run and
            for every dataset that contains the same class folders.
        class_to_idx: class name -> integer id lookup built from ``classes``.
        device: device the image tensor is moved to in ``__getitem__``.
    """

    def __init__(self, imgdir, device="cpu"):
        """Index every file below ``imgdir`` without reading any image data."""
        filedict = {}
        indices = []
        classes = set()
        print("Gathering files for {}".format(imgdir))
        for dirpath, dirnames, filenames in os.walk(imgdir):
            sys.stdout.write('.')
            # Sorting in place makes os.walk descend in a stable order, so the
            # sample at a given index does not depend on the filesystem.
            dirnames.sort()
            # The label is the name of the directory that holds the file.
            arttype = os.path.basename(dirpath)
            for art in sorted(filenames):
                # Key on the relative path: the same filename may exist in
                # several class folders and must not overwrite another entry.
                key = os.path.relpath(os.path.join(dirpath, art), imgdir)
                filedict[key] = WikiArtImage(imgdir, arttype, art)
                indices.append(key)
                classes.add(arttype)
        print("...finished")
        self.filedict = filedict
        self.imgdir = imgdir
        self.indices = indices
        # A set has no stable order across processes; sort so label ids are reproducible.
        self.classes = sorted(classes)
        self.class_to_idx = {name: position for position, name in enumerate(self.classes)}
        self.device = device

    def __len__(self):
        """Number of indexed images."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, label_id)`` for position ``idx``, with the image on ``self.device``."""
        imgobj = self.filedict[self.indices[idx]]
        ilabel = self.class_to_idx[imgobj.label]
        image = imgobj.get().to(self.device)

        return image, ilabel

In [None]:
class WikiArtModel(nn.Module):
    """Small CNN classifier: two conv/ReLU/max-pool stages followed by a two-layer head.

    The input is a batch of 3-channel images of shape ``(batch, 3, height, width)``
    whose spatial size matches ``input_size``. The output is a
    ``(batch, num_classes)`` tensor of log-probabilities (LogSoftmax), intended
    for use with ``nn.NLLLoss``.

    Args:
        num_classes: number of output classes.
        input_size: spatial size of the input images, either one int for square
            images or a ``(height, width)`` pair. The default of 416 yields the
            64 * 106 * 106 flattened features the head was originally sized for.
    """

    def __init__(self, num_classes=27, input_size=416):
        super(WikiArtModel, self).__init__()
        # With kernel 4 / padding 2 each convolution grows the spatial size by one
        # (416 -> 417); each padded 2x2 max-pool then maps n -> n // 2 + 1.
        self.conv2d_1 = nn.Conv2d(3, 32, kernel_size=4, padding=2)   # (32, 417, 417) for 416 input
        self.maxpool2d_1 = nn.MaxPool2d(kernel_size=2, padding=1)    # (32, 209, 209)
        self.conv2d_2 = nn.Conv2d(32, 64, kernel_size=4, padding=2)  # (64, 210, 210)
        self.maxpool2d_2 = nn.MaxPool2d(kernel_size=2, padding=1)    # (64, 106, 106)

        flat_features = self._flattened_features(input_size)

        self.flatten = nn.Flatten()
        self.batchnorm1d = nn.BatchNorm1d(flat_features)
        self.linear1 = nn.Linear(flat_features, 300)
        self.dropout = nn.Dropout(0.3)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(300, num_classes)
        self.softmax = nn.LogSoftmax(dim=1)

    @staticmethod
    def _stage_size(size):
        """Spatial size after one conv (k=4, p=2) + max-pool (k=2, p=1, stride 2) stage."""
        return (size + 1) // 2 + 1

    @staticmethod
    def _flattened_features(input_size, channels=64):
        """Number of features reaching the flatten layer for the given input size."""
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        for _ in range(2):  # two conv + pool stages
            height = WikiArtModel._stage_size(height)
            width = WikiArtModel._stage_size(width)
        return channels * height * width

    def forward(self, image):
        """Return per-class log-probabilities for a batch of images."""
        output = self.conv2d_1(image)
        output = self.relu(output)
        output = self.maxpool2d_1(output)

        output = self.conv2d_2(output)
        output = self.relu(output)
        output = self.maxpool2d_2(output)

        output = self.flatten(output)
        output = self.batchnorm1d(output)
        output = self.linear1(output)
        output = self.dropout(output)
        output = self.relu(output)
        output = self.linear2(output)
        return self.softmax(output)

In [None]:
# TRAIN

print("Time to train...")

# Index every image below the training directory; the dataset moves each
# image tensor to `device` when it is fetched.
traindataset = WikiArtDataset(trainingdir, device=device)

# Sanity check: fetch one sample and show its pixel values and dimensions.
the_image, the_label = traindataset[5]
print(the_image, the_image.shape)

def train(epochs=3, batch_size=32, modelfile=None, device="cpu"):
    """Train a new WikiArtModel on the module-level ``traindataset``.

    Args:
        epochs: number of passes over the training data.
        batch_size: number of images per optimisation step.
        modelfile: if truthy, path the trained ``state_dict`` is saved to.
        device: device the model and each batch are placed on.

    Returns:
        The trained model.
    """
    # BatchNorm1d cannot normalise a single-sample batch in training mode, so the
    # trailing batch is dropped only when it would hold exactly one image.
    drop_last = len(traindataset) % batch_size == 1
    loader = DataLoader(traindataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)

    model = WikiArtModel().to(device)
    model.train()
    optimizer = Adam(model.parameters(), lr=0.01)
    criterion = nn.NLLLoss().to(device)

    for epoch in range(epochs):
        print("Starting epoch {}".format(epoch))
        accumulate_loss = 0.0
        for X, y in tqdm.tqdm(loader):
            # Move both inputs and targets: the dataset may have been built
            # for a different device than the one passed to this function.
            X = X.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(X)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            # .item() yields a plain float, so the running total does not keep
            # every batch's autograd graph referenced.
            accumulate_loss += loss.item()

        print("In epoch {}, loss = {}".format(epoch, accumulate_loss))

    if modelfile:
        torch.save(model.state_dict(), modelfile)

    return model

In [None]:
# Train with the configured hyper-parameters and keep the returned model.
model = train(
    epochs=args["epochs"],
    batch_size=args["batch_size"],
    modelfile=args["modelfile"],
    device=device,
)

In [None]:
# TEST

print("Testing...")

# Index the held-out images; tensors are moved to `device` when fetched.
testingdataset = WikiArtDataset(testingdir, device=device)

def test(modelfile=None, device="cpu"):
    """Evaluate a saved WikiArtModel on the module-level ``testingdataset``.

    Loads the ``state_dict`` stored at ``modelfile``, predicts a class for every
    test image, and prints the accuracy and the confusion matrix.

    Args:
        modelfile: path of the saved model ``state_dict``.
        device: device used for the model and the batches.

    Raises:
        ValueError: if ``modelfile`` is not given.
    """
    if not modelfile:
        raise ValueError("test() needs the path of a saved model state dict (modelfile)")

    loader = DataLoader(testingdataset, batch_size=1)

    model = WikiArtModel()
    # map_location lets a checkpoint saved on one device be loaded on another.
    state = torch.load(modelfile, weights_only=True, map_location=device)
    model.load_state_dict(state)
    model = model.to(device)
    model.eval()

    predictions = []
    truth = []
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for X, y in tqdm.tqdm(loader):
            X = X.to(device)
            y = y.to(device)
            output = model(X)
            # Per-row argmax gives one prediction per image for any batch size.
            predictions.append(output.argmax(dim=1))
            truth.append(y)

    predictions = torch.concat(predictions)
    truth = torch.concat(truth)
    metric = metrics.MulticlassAccuracy()
    metric.update(predictions, truth)
    print("Accuracy: {}".format(metric.compute()))
    # Size the confusion matrix from the model's output layer instead of a literal.
    confusion = metrics.MulticlassConfusionMatrix(model.linear2.out_features)
    confusion.update(predictions, truth)
    print("Confusion Matrix\n{}".format(confusion.compute()))

In [None]:
# Evaluate the checkpoint written by the training cell.
test(
    modelfile=modelfile,
    device=device,
)