In [1]:
import sys
import os
import json
import tqdm
import argparse
import torch
import torch.nn as nn
import numpy as np
import random
from collections import Counter
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision.io import read_image
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
from torch.optim import Adam
from torchvision.io import read_image
import torcheval.metrics as metrics

In [2]:
# Run configuration. Data paths default to the original location but can be
# overridden with the WIKIART_DIR environment variable, so the notebook also
# runs on machines other than the author's.
wikiart_dir = os.environ.get('WIKIART_DIR', '/home/guserbto@GU.GU.SE/wikiart')

args = {
    'modelfile': 'wikiart.pth',
    'trainingdir': os.path.join(wikiart_dir, 'train'),
    'validationdir': os.path.join(wikiart_dir, 'valid'),
    'testingdir': os.path.join(wikiart_dir, 'test'),
    'device': 'cuda:2',
    'epochs': 20,
    'batch_size': 32,
}

modelfile = args['modelfile']
trainingdir = args['trainingdir']
validationdir = args['validationdir']
testingdir = args['testingdir']
device = args['device']
epochs = args['epochs']
batch_size = args['batch_size']

# Fall back to the CPU when the requested CUDA device does not exist on this
# machine (no CUDA at all, or fewer GPUs than the requested index), instead of
# failing later on the first .to(device).
_requested_device = torch.device(device)
if _requested_device.type == 'cuda' and (
    not torch.cuda.is_available()
    or (_requested_device.index is not None
        and _requested_device.index >= torch.cuda.device_count())
):
    print("Requested device {} is unavailable; falling back to cpu".format(device))
    device = 'cpu'

# Seed every RNG used by the notebook for reproducible runs.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Deterministic cuDNN kernels (slower, but reproducible).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
class WikiArtImage:
    """One image file on disk, read lazily and cached after the first access.

    The file lives at ``<imgdir>/<label>/<filename>``. ``get()`` decodes it
    with ``read_image``, converts it to float and divides it by 255, then
    keeps the tensor on the instance so later calls return the same object.
    """

    def __init__(self, imgdir, label, filename):
        self.imgdir, self.label, self.filename = imgdir, label, filename
        self.image = None      # cached tensor, filled on the first get()
        self.loaded = False    # True once self.image holds the decoded image

    def get(self):
        """Return the cached image tensor, loading it from disk on first use."""
        if self.loaded:
            return self.image
        path = os.path.join(self.imgdir, self.label, self.filename)
        self.image = read_image(path).float()
        self.loaded = True
        self.image /= 255.0
        return self.image

In [4]:
class WikiArtDataset(Dataset):
    """Image-classification dataset over a ``<imgdir>/<class_name>/<file>`` tree.

    Every file found under ``imgdir`` becomes one item, labelled by the name of
    the directory that contains it. Items are returned as
    ``(image_tensor_on_device, class_index)``, where ``class_index`` indexes
    ``self.classes``.

    Args:
        imgdir: root directory to walk.
        device: device each image tensor is moved to in ``__getitem__``.
    """

    def __init__(self, imgdir, device="cpu"):
        filedict = {}
        indices = []
        classes = set()
        print("Gathering files for {}".format(imgdir))
        for root, dirnames, filenames in os.walk(imgdir):
            sys.stdout.write('.')
            # Sorting dirnames in place makes os.walk visit subdirectories in a
            # filesystem-independent order; together with sorted filenames this
            # makes item indices reproducible across machines and runs.
            dirnames.sort()
            arttype = os.path.basename(root)
            for art in sorted(filenames):
                # NOTE(review): items are keyed by bare filename, so identical
                # filenames in two class directories collide: filedict keeps the
                # later one while indices lists the name twice.
                filedict[art] = WikiArtImage(imgdir, arttype, art)
                indices.append(art)
                classes.add(arttype)
        print("...finished")
        self.filedict = filedict
        self.imgdir = imgdir
        self.indices = indices
        # Sorted rather than list(set): set iteration order of strings depends on
        # the per-process hash seed and on insertion order. An unsorted list
        # therefore gives different label indices across kernel restarts and
        # between the train and test datasets.
        self.classes = sorted(classes)
        self.device = device

    def __len__(self):
        return len(self.filedict)

    def __getitem__(self, idx):
        """Return ``(image on self.device, index of its label in self.classes)``."""
        imgname = self.indices[idx]
        imgobj = self.filedict[imgname]
        ilabel = self.classes.index(imgobj.label)
        image = imgobj.get().to(self.device)

        return image, ilabel

    def get_image_by_filename(self, filename):
        """Return ``(image, label_index)`` for the item stored under ``filename``.

        Raises:
            ValueError: if no item with that filename exists in the dataset.
        """
        if filename not in self.filedict:
            raise ValueError(f"Filename '{filename}' not found in the dataset.")

        # Retrieve the image object and label
        imgobj = self.filedict[filename]
        ilabel = self.classes.index(imgobj.label)
        image = imgobj.get().to(self.device)

        return image, ilabel

In [5]:
class WikiArtModel(nn.Module):
    """Small CNN mapping RGB images to log-probabilities over art classes.

    Architecture:
    - two conv -> ReLU -> max-pool stages;
    - a BatchNorm1d over the flattened features;
    - a 300-unit hidden layer (dropout 0.3, then ReLU);
    - a LogSoftmax output, so the model pairs with ``nn.NLLLoss``.

    Args:
        num_classes: number of output classes.
        input_size: spatial size of the input images, either an int (square
            images) or a ``(height, width)`` pair. The default of 416 yields the
            64 * 106 * 106 flattened features this model was originally
            hard-coded for (presumably the WikiArt images are 416x416; confirm
            against the data).

    Note:
        With input_size=416, ``linear1`` alone holds 64*106*106*300 ~= 216M
        weights (~823 MiB in float32), before gradients and Adam state. This
        is likely what exhausts memory on small or shared GPUs.
    """

    # Channel count produced by conv2d_2; used to size the flattened features.
    _OUT_CHANNELS = 64

    def __init__(self, num_classes=27, input_size=416):
        super(WikiArtModel, self).__init__()
        self.conv2d_1 = nn.Conv2d(3, 32, kernel_size=4, padding=2)
        self.maxpool2d_1 = nn.MaxPool2d(kernel_size=2, padding=1)

        self.conv2d_2 = nn.Conv2d(32, self._OUT_CHANNELS, kernel_size=4, padding=2)
        self.maxpool2d_2 = nn.MaxPool2d(kernel_size=2, padding=1)

        flat_features = self._flat_features(input_size)
        self.flatten = nn.Flatten()
        self.batchnorm1d = nn.BatchNorm1d(flat_features)
        self.linear1 = nn.Linear(flat_features, 300)
        self.dropout = nn.Dropout(0.3)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(300, num_classes)
        # LogSoftmax despite the attribute name; the name is kept for compatibility.
        self.softmax = nn.LogSoftmax(dim=1)

    @staticmethod
    def _stage_output_size(size):
        """Spatial size after one conv(k=4, p=2, s=1) + maxpool(k=2, s=2, p=1) stage."""
        after_conv = size + 2 * 2 - 4 + 1          # = size + 1
        return (after_conv + 2 * 1 - 2) // 2 + 1   # pooling floors the division

    @classmethod
    def _flat_features(cls, input_size):
        """Number of features produced by flattening the second pooling output."""
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        for _ in range(2):  # two conv + pool stages
            height = cls._stage_output_size(height)
            width = cls._stage_output_size(width)
        return cls._OUT_CHANNELS * height * width

    def forward(self, image):
        """Return log-probabilities of shape (batch, num_classes)."""
        output = self.conv2d_1(image)
        output = self.relu(output)
        output = self.maxpool2d_1(output)

        output = self.conv2d_2(output)
        output = self.relu(output)
        output = self.maxpool2d_2(output)

        output = self.flatten(output)
        output = self.batchnorm1d(output)
        output = self.linear1(output)
        output = self.dropout(output)
        output = self.relu(output)
        output = self.linear2(output)
        return self.softmax(output)

In [6]:
# Prepping label weights for WeightedRandomSampler.
# Each item is weighted by 1 / (number of items sharing its label), so the
# sampler draws every class with roughly equal probability.

dataset = WikiArtDataset(trainingdir, device)

# Read labels from the dataset's file metadata instead of iterating `dataset`.
# Iterating calls __getitem__, which decodes every image from disk, copies it
# to `device`, and caches the decoded tensor on its WikiArtImage object, all
# just to obtain an integer label. This computes the label exactly as
# WikiArtDataset.__getitem__ does, in the same index order.
labels = [
    dataset.classes.index(dataset.filedict[imgname].label)
    for imgname in dataset.indices
]

label_counts = Counter(labels)

label_weights_dict = {label: 1 / count for label, count in label_counts.items()}

weight_per_dataset_item = torch.tensor(
    [label_weights_dict[label] for label in labels], dtype=torch.float
)

Gathering files for /home/guserbto@GU.GU.SE/wikiart/train
...............................finished


In [7]:
# Weighted Random Sampler
# One draw per training item per epoch, with replacement; each item is picked
# with probability proportional to its entry in weight_per_dataset_item.
sampler = WeightedRandomSampler(
    weights=weight_per_dataset_item,
    num_samples=len(weight_per_dataset_item),
    replacement=True,
)

In [8]:
# TRAIN

print("Time to train...")

traindataset = WikiArtDataset(trainingdir, device)

def train(epochs=3, batch_size=32, modelfile=None, device="cpu"):
    """Train a fresh WikiArtModel on `traindataset`, drawing batches via `sampler`.

    Uses Adam (lr=0.001) with NLLLoss on the model's log-probabilities and
    prints the summed training loss after every epoch.

    Args:
        epochs: number of passes over the sampler.
        batch_size: images per batch.
        modelfile: if given, the trained state_dict is saved to this path.
        device: device the model and batches are placed on.

    Returns:
        The trained WikiArtModel.
    """
    train_loader = DataLoader(traindataset, batch_size=batch_size, sampler=sampler)

    model = WikiArtModel().to(device)
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.NLLLoss().to(device)

    for epoch in range(epochs):
        print("Starting epoch {}".format(epoch))
        model.train()
        accumulate_loss = 0.0
        for batch_id, batch in enumerate(tqdm.tqdm(train_loader)):
            X, y = batch
            # The dataset already moves images to its own device; this is a
            # no-op then, but keeps train() correct for any dataset device.
            X = X.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(X)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            # .item() yields a plain float. Summing the loss tensor itself
            # would keep every step's autograd nodes alive until epoch end.
            accumulate_loss += loss.item()

        print("In epoch {}, loss = {}".format(epoch, accumulate_loss))

    if modelfile:
        torch.save(model.state_dict(), modelfile)

    return model

Time to train...
Gathering files for /home/guserbto@GU.GU.SE/wikiart/train
...............................finished


In [11]:
# Train with the configured hyperparameters and save the weights to args["modelfile"].
# NOTE: a previous run of this cell hit CUDA OOM on cuda:2, where another
# process already held most of the GPU's memory.
model = train(
    epochs=args["epochs"],
    batch_size=args["batch_size"],
    modelfile=args["modelfile"],
    device=device,
)

OutOfMemoryError: CUDA out of memory. Tried to allocate 824.00 MiB. GPU 2 has a total capacity of 10.90 GiB of which 161.31 MiB is free. Process 627257 has 8.38 GiB memory in use. Including non-PyTorch memory, this process has 2.37 GiB memory in use. Of the allocated memory 2.06 GiB is allocated by PyTorch, and 155.83 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# TEST

print("Testing...")

testingdataset = WikiArtDataset(testingdir, device)

def test(modelfile=None, device="cpu"):
    """Evaluate a saved WikiArtModel checkpoint on `testingdataset`.

    Prints the multiclass accuracy and the confusion matrix.

    Args:
        modelfile: path to a state_dict written by train().
        device: device the model and batches are placed on.
    """
    loader = DataLoader(testingdataset, batch_size=1)

    model = WikiArtModel()
    # map_location="cpu": the checkpoint may have been saved from a GPU (e.g.
    # cuda:2) that does not exist here; the model is moved to `device` below.
    model.load_state_dict(torch.load(modelfile, map_location="cpu", weights_only=True))
    model = model.to(device)
    model.eval()

    predictions = []
    truth = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch_id, batch in enumerate(tqdm.tqdm(loader)):
            X, y = batch
            X = X.to(device)
            y = y.to(device)
            output = model(X)
            # Per-row argmax; correct for any batch size, shape (batch,).
            predictions.append(torch.argmax(output, dim=1))
            truth.append(y)

    predictions = torch.concat(predictions)
    truth = torch.concat(truth)
    # Size the confusion matrix from the model's output layer, not a literal.
    num_classes = model.linear2.out_features
    metric = metrics.MulticlassAccuracy()
    metric.update(predictions, truth)
    print("Accuracy: {}".format(metric.compute()))
    confusion = metrics.MulticlassConfusionMatrix(num_classes)
    confusion.update(predictions, truth)
    print("Confusion Matrix\n{}".format(confusion.compute()))

In [None]:
# Evaluate the saved checkpoint on the test split.
test(device=device, modelfile=modelfile)