In [1]:
from pathlib import Path
import os
import sys
from datetime import datetime
import numpy as np

In [2]:
# Parent of the (resolved) current working directory; used below as the
# project root that gets added to `sys.path`.
PROJECT_DIR = Path.cwd().resolve().parent
PROJECT_DIR

PosixPath('/home/nico/Thesis/neural-artwork-caption-generator')

In [3]:
# Make the project root importable so the `src.*` imports below resolve.
# Guarded so that re-running this cell does not keep appending duplicates.
if str(PROJECT_DIR) not in sys.path:
    sys.path.append(str(PROJECT_DIR))

In [4]:
from src.models.multiclassification.model import ViTForMultiClassification
from src.models.multiclassification.data import get_multiclassification_dicts
from src.utils.dirutils import get_data_dir
from datasets import Dataset, load_from_disk
from transformers import ViTImageProcessor
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Two dictionaries describing the classification heads; their keys are used
# as the feature names throughout the notebook.
multiclass_classifications, multilabel_classifications = get_multiclassification_dicts()
multiclass_features = list(multiclass_classifications.keys())
multilabel_features = list(multilabel_classifications.keys())
# Multiclass features first, then multilabel: the per-feature loss lists
# produced later follow this same order.
all_features = [*multiclass_features, *multilabel_features]

model = ViTForMultiClassification(multiclass_classifications, multilabel_classifications)
model.to(device)

dataset: Dataset = load_from_disk(
    get_data_dir() / "processed" / "multiclassification_dataset"
)
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

Some weights of the model checkpoint at google/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
def process(examples):
    """Turn a batch of raw dataset rows into tensors placed on `device`.

    The images under ``examples["image"]`` are run through the ViT image
    processor; every multiclass feature becomes a ``long`` tensor and every
    multilabel feature a ``float`` tensor.

    Note: ``examples`` is modified in place (a ``"pixel_values"`` entry is
    added to it), exactly as the returned dict also carries that entry.

    Returns:
        dict: one tensor per feature name plus ``"pixel_values"``.
    """
    examples["pixel_values"] = processor(examples["image"], return_tensors="pt").pixel_values

    # (feature name, target dtype) pairs, multiclass first then multilabel.
    typed_features = [(name, torch.long) for name in multiclass_features]
    typed_features += [(name, torch.float) for name in multilabel_features]

    batch = {
        name: torch.tensor(examples[name], dtype=dtype, device=device)
        for name, dtype in typed_features
    }
    batch["pixel_values"] = examples["pixel_values"].to(device)
    return batch

In [7]:
# Register `process` as the dataset's transform so rows are converted to
# tensors when they are accessed (see `datasets.Dataset.with_transform`).
dataset = dataset.with_transform(process)

In [8]:
# Only the first SUBSET_SIZE rows of each split are used; the size is clamped
# so a split with fewer rows does not make `select` fail on out-of-range indices.
SUBSET_SIZE = 1000
BATCH_SIZE = 8

# NOTE(review): the training loader is not shuffled — confirm this is intended.
train_loader = torch.utils.data.DataLoader(
    dataset["train"].select(range(min(SUBSET_SIZE, len(dataset["train"])))),
    batch_size=BATCH_SIZE,
)
validation_loader = torch.utils.data.DataLoader(
    dataset["validation"].select(range(min(SUBSET_SIZE, len(dataset["validation"])))),
    batch_size=BATCH_SIZE,
)

In [9]:
# custom loss for multilabel classification, ignore elements that have 0 labels
def binary_cross_entropy_with_logits_ignore_no_labels(preds, targets):
    """Mean BCE-with-logits over the rows that carry at least one label.

    A row whose targets sum to zero along the last dimension is treated as
    "unlabeled" and ignored: it contributes neither to the summed loss nor to
    the normalisation. This mirrors how ``F.cross_entropy(ignore_index=...)``
    averages over the non-ignored targets only, so the loss scale no longer
    depends on how many unlabeled rows a batch happens to contain.

    Args:
        preds: logits; the last dimension indexes the labels.
        targets: multi-hot targets broadcastable to ``preds`` (cast to float).

    Returns:
        Scalar tensor. Zero (still attached to the graph) when every row is
        unlabeled.
    """
    targets = targets.float()
    has_labels = torch.sum(targets, dim=-1) != 0

    per_element = F.binary_cross_entropy_with_logits(preds, targets, reduction="none")
    # torch.where (rather than multiplying by a mask) so that a non-finite
    # loss value in an ignored row cannot leak into the sum.
    kept = torch.where(has_labels.unsqueeze(-1), per_element, torch.zeros_like(per_element))

    # Number of loss elements that actually count; clamped so that a batch
    # with no labeled rows yields 0 instead of 0/0.
    n_kept = (has_labels.sum() * per_element.shape[-1]).clamp(min=1)
    return kept.sum() / n_kept

In [10]:
# Adam over all of the model's parameters with a fixed learning rate.
# NOTE(review): 5e-3 looks high for a model initialised from a pretrained
# ViT checkpoint — confirm this value is intended.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)

In [11]:
def losses_fn(outputs, targets):
    """Compute one loss per feature, ordered like `all_features`.

    Multiclass features use cross-entropy where a target of ``-1`` is ignored;
    multilabel features use
    ``binary_cross_entropy_with_logits_ignore_no_labels``.

    Args:
        outputs: mapping feature name -> logits.
        targets: mapping feature name -> target tensor.

    Returns:
        list of scalar tensors: multiclass losses first, then multilabel.
    """
    losses = []
    for feature in multiclass_features:
        # reshape(-1) instead of squeeze(): squeeze() would collapse a batch
        # of size 1 into a 0-d tensor, which cross_entropy rejects.
        labels = targets[feature].reshape(-1)
        # Sum / count-of-valid equals the default "mean" reduction whenever at
        # least one target is valid, but yields 0 instead of NaN (0/0) when
        # every target in the batch is the ignore index.
        n_valid = (labels != -1).sum().clamp(min=1)
        loss = F.cross_entropy(outputs[feature], labels, ignore_index=-1, reduction="sum") / n_valid
        losses.append(loss)
    for feature in multilabel_features:
        loss = binary_cross_entropy_with_logits_ignore_no_labels(outputs[feature], targets[feature])
        losses.append(loss)

    return losses

In [12]:
def join_losses(model, losses):
    """Collapse the per-feature losses into a single scalar.

    When ``model.log_vars`` is ``None`` the losses are simply summed.
    Otherwise loss ``i`` is weighted with the matching entry of
    ``model.log_vars``: ``exp(-log_var) * loss + log_var / 2``.

    Args:
        model: object exposing a ``log_vars`` attribute (indexable or None).
        losses: sequence of per-feature loss tensors.

    Returns:
        The sum of the (possibly weighted) losses.
    """
    if model.log_vars is None:
        return sum(losses)

    weighted = [
        torch.exp(-model.log_vars[idx]) * task_loss + model.log_vars[idx] / 2
        for idx, task_loss in enumerate(losses)
    ]
    return sum(weighted)

In [13]:
def train_one_epoch():
    """Run one optimisation pass over `train_loader`.

    For each batch: reset gradients, run the model, compute the per-feature
    losses, join them into one scalar, back-propagate and step the optimizer.

    Returns:
        tuple: (joined loss averaged over the batches,
                list of per-feature losses averaged over the batches, in
                `all_features` order).
    """
    total_loss = 0.
    per_feature_totals = [0.] * len(all_features)
    n_batches = len(train_loader)

    # tqdm only adds a progress bar; batches are consumed in loader order.
    for batch in tqdm(train_loader):
        # Everything in the batch except the image tensor is a target.
        inputs = batch["pixel_values"]
        targets = {name: batch[name] for name in batch if name != "pixel_values"}

        # Gradients accumulate by default, so clear them before each step.
        optimizer.zero_grad()

        outputs = model(inputs)
        losses = losses_fn(outputs, targets)
        loss = join_losses(model, losses)

        # Accumulate plain Python floats for reporting.
        for idx in range(len(all_features)):
            per_feature_totals[idx] += losses[idx].item()
        total_loss += loss.item()

        # Back-propagate the joined loss and update the weights.
        loss.backward()
        optimizer.step()

    mean_loss = total_loss / n_batches
    mean_feature_losses = [total / n_batches for total in per_feature_totals]
    return mean_loss, mean_feature_losses

In [14]:
def concatenate_batch_arrays(all_data):
    """Merge a list of per-batch dicts of arrays into one dict of arrays.

    Arrays that share a key are joined along axis 0 in batch order. Chunks are
    gathered first and concatenated once per key, which avoids re-copying the
    growing result on every batch (the pairwise approach is quadratic in the
    number of batches).

    Args:
        all_data: iterable of dicts mapping key -> numpy array.

    Returns:
        dict mapping every key seen to the concatenated array. A key that
        occurs in a single batch maps to that batch's array unchanged (no
        copy). An empty input yields an empty dict.
    """
    chunks_by_key = {}
    for batch in all_data:
        for key in batch:
            chunks_by_key.setdefault(key, []).append(batch[key])

    return {
        key: chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0)
        for key, chunks in chunks_by_key.items()
    }

In [15]:
@torch.no_grad()
def eval():
    """Evaluate the model on `validation_loader` with gradients disabled.

    Note: this function shadows the builtin ``eval`` in the notebook
    namespace; the name is kept because the training loop calls it. It does
    not switch the model to evaluation mode itself — the caller does that.

    Returns:
        tuple: (joined loss averaged over the validation batches,
                list of per-feature losses averaged over the validation
                batches, in `all_features` order,
                dict of metric name -> value; currently always empty because
                no metrics are computed yet).
    """
    running_vloss = 0.
    running_label_vlosses = [0.] * len(all_features)
    all_voutputs = []
    all_vtargets = []

    for vbatch in tqdm(validation_loader):
        # Compute batch outputs
        vinputs, vtargets = vbatch["pixel_values"], {k: vbatch[k] for k in vbatch if k != "pixel_values"}
        voutputs = model(vinputs)

        # Compute batch losses
        vlosses = losses_fn(voutputs, vtargets)
        vloss = join_losses(model, vlosses)
        for i, _ in enumerate(all_features):
            running_label_vlosses[i] += vlosses[i].item()
        running_vloss += vloss.item()

        # Keep predictions and targets (as numpy arrays on the CPU) so that
        # metrics can be computed over the whole validation set.
        all_voutputs.append(
            {k: v.detach().cpu().numpy() for k, v in voutputs.items()}
        )
        all_vtargets.append(
            {k: v.detach().cpu().numpy() for k, v in vtargets.items()}
        )

    avg_vloss = running_vloss / len(validation_loader)
    running_label_vlosses = [loss / len(validation_loader) for loss in running_label_vlosses]
    all_voutputs = concatenate_batch_arrays(all_voutputs)
    all_vtargets = concatenate_batch_arrays(all_vtargets)

    # TODO: derive metrics from all_voutputs / all_vtargets. An empty dict
    # (rather than None) is returned so callers can iterate `.items()` safely.
    metrics = {}
    return avg_vloss, running_label_vlosses, metrics

In [16]:
# Run bookkeeping lives in the same cell as the loop, so re-running this cell
# starts a fresh run: new timestamp / TensorBoard directory, epoch counter
# reset to 0.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    train_loss, train_label_losses = train_one_epoch()
    for i, label in enumerate(all_features):
        writer.add_scalar(f"train/{label}_loss", train_label_losses[i], epoch_number + 1)
    writer.add_scalar("train/loss", train_loss, epoch_number + 1)

    # The notebook falls back to the CPU when CUDA is unavailable, and
    # torch.cuda.memory_summary() raises in that case, so only report GPU
    # memory when a GPU is actually present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(torch.cuda.memory_summary())

    # We don't need gradients on to do reporting
    model.train(False)
    avg_vloss, running_label_vlosses, metrics = eval()
    print('LOSS train {} valid {}'.format(train_loss, avg_vloss))

    # Log valid losses. eval() already returns per-batch averages, so they
    # must not be divided by the number of batches a second time.
    for i, label in enumerate(all_features):
        writer.add_scalar(f"valid/{label}_loss", running_label_vlosses[i], epoch_number + 1)
    writer.add_scalar("valid/loss", avg_vloss, epoch_number + 1)

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('train_vs_valid_loss',
                    {'train': train_loss, 'valid': avg_vloss},
                    epoch_number + 1)
    # Tolerate an eval() that reports no metrics (None or an empty dict).
    for k, v in (metrics or {}).items():
        writer.add_scalar(f'valid/{k}', v, epoch_number + 1)
    writer.flush()

    # Remember the best validation loss seen so far in this run.
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss

    # Save a checkpoint (model + optimizer state) after every epoch.
    model_path = f"model-{timestamp}-{epoch_number + 1}.pt"
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict()}, model_path)

    epoch_number += 1

EPOCH 1:


  4%|▍         | 5/125 [00:04<01:50,  1.08it/s]


KeyboardInterrupt: 