In [None]:
!pip install --upgrade wandb -qqq

import torch
from torch.optim import Adam
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision.models import resnet50, densenet201, ResNet50_Weights, DenseNet201_Weights
import torchvision.transforms as transforms
import lightning.pytorch as pl
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from imgaug import augmenters as iaa
import pandas as pd
import string
import os
import wandb
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import copy
from typing import Tuple, List
from nltk.translate.bleu_score import sentence_bleu
import tkinter
import tkinter.filedialog
from tqdm import tqdm 
import inspect
import sys

#make the parent directory importable
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

#let float32 matmuls use faster, lower-precision internal computation on supported hardware
torch.set_float32_matmul_precision('high')

#wandb configuration
user = "t_buess"
project = "del_mc2"
#os.path.join instead of a hard-coded Windows "\\" separator (identical result on Windows, correct elsewhere)
os.environ["WANDB_NOTEBOOK_NAME"] = os.path.join(os.path.abspath(""), "del_mc2_tobias_buess.ipynb")

wandb.login()

In [2]:
def tmp():
    import torch
    from torch.nn.utils.rnn import pad_sequence
    from torch.utils.data import DataLoader, Dataset
    from PIL import Image
    from imgaug import augmenters as iaa
    import pandas as pd
    import string
    import os
    import numpy as np
    import copy
    from tqdm import tqdm 

    class Tokenizer:
        """Converts caption text to integer token ids / one-hot vectors and back.

        Ids 0-2 are reserved for the special tokens <SOS>, <EOS> and <PAD>. `build_vocab` assigns further ids
        in the order in which tokens first reach `min_token_count` occurrences.
        """

        #translation table that deletes all ASCII punctuation characters
        _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

        def __init__(self, min_token_count:int=3) -> None:
            """
            Args:
                min_token_count (int): minimum number of occurrences in the corpus before a token enters the vocabulary
            """

            self.tok_to_index = {"<SOS>":0, "<EOS>":1, "<PAD>":2} #token -> index
            self.index_to_token = ["<SOS>", "<EOS>", "<PAD>"] #index -> token
            self.min_token_count = min_token_count

        def build_vocab(self, corpus:list) -> None:
            """Extend the vocabulary with all tokens of `corpus` that occur at least `min_token_count` times.

            Args:
                corpus (list): list of caption strings
            """

            next_index = len(self.index_to_token)
            token_count = {} #occurrences of not yet registered tokens

            for element in corpus:
                for token in Tokenizer.tokenize_text(element):
                    if token in self.tok_to_index: #O(1) dict lookup instead of scanning the index_to_token list
                        continue

                    token_count[token] = token_count.get(token, 0) + 1

                    if token_count[token] >= self.min_token_count: #threshold reached -> register token
                        self.tok_to_index[token] = next_index
                        self.index_to_token.append(token)
                        next_index += 1

        def numericalize(self, caption:str) -> torch.Tensor:
            """Convert a caption to a 1-d tensor of token indices; tokens not in the vocabulary are dropped.

            Args:
                caption (str): caption text

            Returns:
                int64 tensor of token indices (also int64 when empty, so it concatenates cleanly with other index tensors)
            """

            token_ids = [self.tok_to_index[token] for token in Tokenizer.tokenize_text(caption) if token in self.tok_to_index]
            return torch.tensor(token_ids, dtype=torch.long)

        def numerical_to_matrix(self, numeric_caption:torch.Tensor) -> torch.Tensor:
            """One-hot encode the output of `numericalize` (works for tensors of any rank, including 0-d).

            Args:
                numeric_caption (torch.Tensor): tensor of token indices

            Returns:
                float tensor of shape numeric_caption.shape + (vocab_size,)
            """

            rank = len(numeric_caption.shape)
            return torch.zeros((list(numeric_caption.shape) + [len(self)])).scatter_(rank, numeric_caption.unsqueeze(rank).type(torch.int64), 1)

        def oneHot_sequence_to_tokens(self, sequence:torch.Tensor):
            """Inverse of `numerical_to_matrix` for a sequence of one-hot (or score) vectors.

            Args:
                sequence (torch.Tensor): tensor of shape (sequence_len x vocab_size); the argmax of each row is used

            Returns:
                list of token strings
            """

            return [self.index_to_token[index] for index in sequence.argmax(dim=1).cpu().tolist()]

        def numerical_to_tokens(self, sequence:torch.Tensor):
            """Inverse of `numericalize`.

            Args:
                sequence (torch.Tensor): 1-d tensor of integer token indices (any device)

            Returns:
                list of token strings
            """

            return [self.index_to_token[index] for index in sequence.tolist()]

        def __len__(self):
            """Size of the vocabulary (including the special tokens)."""
            return len(self.index_to_token)

        @staticmethod
        def tokenize_text(text:str) -> list:
            """Lower-case `text`, remove punctuation and split it into tokens.

            `split()` without a separator is used so that consecutive / leftover whitespace (e.g. from a removed
            " , ") does not produce empty-string tokens.
            """
            return text.lower().translate(Tokenizer._PUNCTUATION_TABLE).split()

    class FlickrDataset(Dataset):
        """Image/caption dataset backed by an image folder and a caption CSV file.

        The instance itself serves the training split (one row per image/caption pair). `get_validation` and
        `get_test` return shallow copies serving the other splits. The split is done per image, so all captions
        of one image end up in the same split.
        """

        def __init__(self, img_root_dir:str, captions_file:str, img_transform=None, train_amount:float=0.8, valid_amount:float=0.1, split_random_state:int=1234, min_token_count:int=3, train_augmentation_multiplier:int=0) -> None:
            """
            Args:
                img_root_dir (str): folder containing the image files
                captions_file (str): CSV file with a header row and the two columns image file name, caption
                img_transform: optional callable applied to every PIL image
                train_amount (float): fraction of the images used for training
                valid_amount (float): fraction of ALL images used for validation; the remaining images form the test set
                split_random_state (int): random seed of the split
                min_token_count (int): minimum number of occurrences for a token to enter the vocabulary
                train_augmentation_multiplier (int): if > 0 the training rows are repeated this many times and
                    random image augmentation is enabled
            """
            super().__init__()
            self.img_root_dir = img_root_dir #folder with the images
            self.captions_file = captions_file #file with the captions
            self.img_transform = img_transform #optional image transformation

            self.captions = pd.read_csv(self.captions_file, header=0, names=["img", "caption"])

            #NOTE(review): the vocabulary is built from ALL captions, i.e. also from validation and test captions
            self.tokenizer = Tokenizer(min_token_count=min_token_count)
            self.tokenizer.build_vocab(self.captions.caption.values)

            #one row per image with the list of all its captions
            self.captions = self.captions.groupby("img").caption.apply(list).to_frame().reset_index()

            #train / validation / test split over images
            self.train_captions = self.captions.sample(frac=train_amount, random_state=split_random_state)
            remaining_captions = self.captions.drop(self.train_captions.index)
            remaining_amount = 1 - train_amount
            #valid_amount refers to the whole dataset -> rescale it to the remaining part. The expression is kept
            #unchanged so existing splits are reproduced exactly; the guards avoid a ZeroDivisionError for
            #train_amount == 1 and the DataFrame.sample error for fractions > 1.
            valid_frac = min(1/remaining_amount*valid_amount, 1.0) if remaining_amount > 0 else 0.0
            self.valid_caption = remaining_captions.sample(frac=valid_frac, random_state=split_random_state)
            self.test_caption = remaining_captions.drop(self.valid_caption.index)

            tot_len = len(self.train_captions) + len(self.valid_caption) + len(self.test_caption)

            print(f"Dataset was fractioned into:\n- train: {len(self.train_captions)/tot_len}\n- validation: {len(self.valid_caption)/tot_len}\n- test: {len(self.test_caption)/tot_len}")

            self.captions = self.train_captions.explode("caption") #one row per (image, caption) pair of the training split

            self.use_augmentation = False

            if train_augmentation_multiplier > 0:
                self.captions = pd.concat([self.captions]*train_augmentation_multiplier, axis=0) #repeat the training rows n times
                self.use_augmentation = True
                print(f"Data augmentation on training dataset enabled. Trainset is now {train_augmentation_multiplier} times its original size")

            self.is_test = False #only True for the copy returned by get_test

        def get_validation(self):
            """Shallow copy serving the validation split (one row per image/caption pair, no augmentation)."""
            ds_c = copy.copy(self)
            ds_c.captions = self.valid_caption.explode("caption")
            ds_c.use_augmentation = False

            return ds_c

        def get_test(self):
            """Shallow copy serving the test split: one row per image with the list of all its reference captions."""
            ds_c = copy.copy(self)
            ds_c.captions = self.test_caption
            ds_c.use_augmentation = False

            ds_c.is_test = True

            return ds_c

        def augmentate_image(self, img:np.array):
            """Randomly augment an image (blur, horizontal flip, rotation, pixel dropout, hue & saturation)."""
            aug = iaa.Sequential([
                iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 3.0))),
                iaa.Fliplr(0.5),
                iaa.Affine(rotate=(-20, 20), mode='symmetric'),
                iaa.Sometimes(0.5,
                            iaa.OneOf([iaa.Dropout(p=(0, 0.1)),
                                        iaa.CoarseDropout(0.1, size_percent=0.5)])),
                iaa.AddToHueAndSaturation(value=(-10, 10), per_channel=True)
            ])

            return aug.augment_image(img)

        def __len__(self):
            return len(self.captions)

        def __getitem__(self, idx):
            """Return (image, caption indices incl. <SOS>/<EOS>) or, for the test split, (image, list of caption strings)."""
            if torch.is_tensor(idx):
                idx = idx.tolist()

            img_id, caption = self.captions.iloc[idx]

            #convert to RGB so every sample has 3 channels regardless of the file's mode, and close the file handle
            with Image.open(os.path.join(self.img_root_dir, img_id)) as raw_img:
                img = raw_img.convert("RGB")

            if self.use_augmentation:
                img = Image.fromarray(self.augmentate_image(np.array(img)))

            if self.img_transform is not None:
                img = self.img_transform(img)

            if self.is_test: #caption is the list of all reference captions of the image
                return img, caption

            #single caption -> <SOS> + token indices + <EOS>
            numericalized_caption = torch.cat((torch.tensor([self.tokenizer.tok_to_index["<SOS>"]]), self.tokenizer.numericalize(caption), torch.tensor([self.tokenizer.tok_to_index["<EOS>"]])))
            return img, numericalized_caption
        
    class Collate:
        """Collate function for training/validation batches.

        Stacks the images and right-pads the caption index sequences with <PAD>. The padded sequences are then
        converted to one-hot vectors.
        """

        def __init__(self, dataset:FlickrDataset):
            self.dataset = dataset
            self.pad_idx = dataset.tokenizer.tok_to_index["<PAD>"]

        def __call__(self, batch):
            #(batch x C x H x W)
            images = torch.stack([sample[0] for sample in batch], dim=0)

            #(batch x max_len) token indices, shorter captions padded with <PAD>
            padded_indices = pad_sequence([sample[1] for sample in batch], batch_first=True, padding_value=self.pad_idx)

            #(batch x max_len x vocab_size) one-hot targets
            return images, self.dataset.tokenizer.numerical_to_matrix(padded_indices)
        
    class Collate_Test:
        """Collate function for test batches.

        Stacks the images and leaves each image's reference captions (a list of strings) untouched.
        """

        def __init__(self, dataset:FlickrDataset):
            self.dataset = dataset
            self.pad_idx = dataset.tokenizer.tok_to_index["<PAD>"]

        def __call__(self, batch):
            images = torch.stack([sample[0] for sample in batch], dim=0)
            reference_captions = [sample[1] for sample in batch]
            return images, reference_captions
    
    def get_loader(img_root_dir, captions_file, img_transform=None, batch_size=64, train_amount=0.8, valid_amount=0.1, min_token_count=3, train_augmentation_multiplier=0, num_workers_train:int=16, split_random_state:int=1234):
        """Build the train/validation/test datasets and their DataLoaders.

        Args:
            img_root_dir: folder containing the images
            captions_file: caption CSV file
            img_transform: optional transform applied to every image
            batch_size (int): batch size of all loaders
            train_amount (float): fraction of images used for training
            valid_amount (float): fraction of all images used for validation
            min_token_count (int): minimum token frequency for the vocabulary
            train_augmentation_multiplier (int): > 0 enables augmentation and repeats the training set
            num_workers_train (int): worker processes of the training loader (0 = load in the main process)
            split_random_state (int): random seed of the train/validation/test split

        Returns:
            (train_loader, validation_loader, test_loader, train_dataset, valid_dataset, test_dataset)
        """
        train_dataset = FlickrDataset(img_root_dir, captions_file, img_transform, train_amount=train_amount, valid_amount=valid_amount, split_random_state=split_random_state, min_token_count=min_token_count, train_augmentation_multiplier=train_augmentation_multiplier)

        valid_dataset = train_dataset.get_validation()
        test_dataset = train_dataset.get_test()

        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True,
            collate_fn=Collate(train_dataset),
            num_workers=num_workers_train,
            persistent_workers=num_workers_train > 0 #DataLoader raises a ValueError for persistent workers with num_workers == 0
        )

        validation_loader = DataLoader(
            dataset=valid_dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=True,
            collate_fn=Collate(valid_dataset),
        )

        test_loader = DataLoader(
            dataset=test_dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=True,
            collate_fn=Collate_Test(test_dataset),
        )

        return train_loader, validation_loader, test_loader, train_dataset, valid_dataset, test_dataset

#temporarily move the dataset code into a module file (otherwise num_workers of the train loader won't work;
#presumably because worker processes must import these classes by name, which classes defined in the notebook do not allow)
import importlib

_TMP_MODULE_NAME = "tmp_ksdbf97skd"

#drop the "def tmp():" line and the 4-space indentation of its body; write as UTF-8 since the source contains non-ASCII text
with open(f"./{_TMP_MODULE_NAME}.py", "w", encoding="utf-8") as file:
    file.write("\n".join([line[4:] for line in inspect.getsource(tmp).split("\n")[1:]]))

#on a re-run of this cell, pick up the rewritten file instead of the already imported (stale) module
importlib.invalidate_caches()
if _TMP_MODULE_NAME in sys.modules:
    importlib.reload(sys.modules[_TMP_MODULE_NAME])

from tmp_ksdbf97skd import Tokenizer, FlickrDataset, Collate, Collate_Test, get_loader #use the outsourced functions and classes

In [3]:
class RNN(nn.Module):
    """Recurrent caption decoder.

    The image embedding is the first element of the input sequence, followed by the embedded one-hot caption
    tokens. `self.lstm` is an `nn.LSTM` here; `make_rnn_gru` replaces it with an `nn.GRU`. The recurrent state
    passed around below is therefore either an `(h, c)` tuple or a single tensor.
    """

    def __init__(self, embedding_size:int, hidden_size:int, vocab_size:int, dropout:float=0, num_layers:int=2) -> None:
        """
        Args:
            embedding_size (int): size of the token embeddings (must equal the size of the image embedding)
            hidden_size (int): hidden size of the recurrent layer
            vocab_size (int): length of the one-hot token vectors
            dropout (float): dropout between stacked recurrent layers
            num_layers (int): number of stacked recurrent layers
        """
        super().__init__()

        self.embedding = nn.Linear(vocab_size, embedding_size) #a linear layer on one-hot input acts as an embedding table (plus bias)
        self.lstm = nn.LSTM(embedding_size, hidden_size, dropout=dropout, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size) #hidden state -> vocabulary logits

    def forward(self, image_embedding:torch.Tensor, onehot_words:torch.Tensor):
        """Teacher-forced forward pass.

        Args:
            image_embedding (torch.Tensor): image embedding (batch_size x embedding_size)
            onehot_words (torch.Tensor): one-hot tokens (batch_size x sequence_len x vocab_size)

        Returns:
            log-probabilities of shape (batch_size x (1 + sequence_len) x vocab_size). Position 0 is the output
            after the image; position t > 0 is the output after token t-1.
        """
        image_embedding = image_embedding.unsqueeze(1) #add a sequence dimension
        word_embeddings = self.embedding(onehot_words)

        embeddings = torch.cat((image_embedding, word_embeddings), dim=1) #image first, then the words

        x, _ = self.lstm(embeddings)
        x = self.fc(x)
        return F.log_softmax(x, dim=2) #log-probabilities for F.nll_loss

    def sample_greedy(self, image_embedding:torch.Tensor, start_token_onehot:torch.Tensor, stop_token_onehot:torch.Tensor, max_num_words:int=10):
        """Caption an image greedily (most likely token at each step) until the stop token is predicted.

        Args:
            image_embedding (torch.Tensor): image embedding (1 x embedding_size)
            start_token_onehot (torch.Tensor): one-hot start token (vocab_size)
            stop_token_onehot (torch.Tensor): one-hot stop token (vocab_size)
            max_num_words (int): maximum number of predicted tokens (the stop token is not counted)

        Returns:
            stacked one-hot tokens (num_tokens x vocab_size), or the 0-dim tensor `torch.empty([])` if no token
            was predicted before the stop token
        """
        with torch.no_grad():
            prediction = []

            #feed the image first; only the resulting recurrent state is needed
            _, _, hidden = self._step_next_token(image_embedding.unsqueeze(1))

            token_onehot = start_token_onehot.unsqueeze(0).unsqueeze(0) #(1 x 1 x vocab_size)
            for _ in range(max_num_words): #at most max_num_words tokens are produced
                _, token_onehot, hidden = self._step_next_token(self.embedding(token_onehot), hidden)

                if torch.equal(token_onehot[0, 0, :], stop_token_onehot):
                    break

                prediction.append(token_onehot[0, 0, :])

            if len(prediction) == 0:
                return torch.empty([])

            return torch.stack(prediction)

    def sample_beamSearch(self, image_embedding:torch.Tensor, start_token_onehot:torch.Tensor, stop_token_onehot:torch.Tensor, max_num_words:int=3, beam_size:int=2):
        """Caption an image with a depth-first branch-and-bound search over the `beam_size` most likely tokens per step.

        A branch ends when the stop token is chosen or after `max_num_words` tokens. A branch is pruned as soon as
        its summed log-probability falls below that of the best finished caption found so far; log-probabilities
        are <= 0, so a branch cannot improve again. The worst case is exponential in `max_num_words`.

        Args:
            image_embedding (torch.Tensor): image embedding (1 x embedding_size)
            start_token_onehot (torch.Tensor): one-hot start token (vocab_size)
            stop_token_onehot (torch.Tensor): one-hot stop token (vocab_size)
            max_num_words (int): maximum caption length
            beam_size (int): number of candidate tokens expanded per step

        Returns:
            1-d int32 CPU tensor with the token indices of the best caption (stop token excluded; may be empty)
        """
        with torch.no_grad():
            stop_idx = int(stop_token_onehot.argmax().item())
            global_best_prob = float("-inf") #log-probability of the best finished caption so far

            #feed the image first; only the resulting recurrent state is needed
            _, _, hidden = self._step_next_token(image_embedding.unsqueeze(1))

            def search(last_token_onehot:torch.Tensor, last_prob:float, hidden, n:int=0):
                """Return the token list of a caption better than every caption found before, or None."""
                nonlocal global_best_prob

                n += 1

                x, hidden = self.lstm(self.embedding(last_token_onehot), hidden)
                x = self.fc(x[:, [-1], :])
                log_probs = torch.log_softmax(x, dim=2)

                top = torch.topk(log_probs, beam_size, dim=2, largest=True) #sorted in descending order
                best_prediction = None
                for idx, prob_t in zip(top.indices.flatten().tolist(), top.values.flatten().tolist()):
                    prob_current = last_prob + prob_t

                    #candidates are sorted, so all remaining ones would be pruned as well
                    if prob_current < global_best_prob:
                        break

                    if idx == stop_idx or n > max_num_words:
                        #caption finished and at least as good as the best so far
                        global_best_prob = prob_current
                        best_prediction = []
                    else:
                        idx_onehot = torch.zeros_like(log_probs)
                        idx_onehot[0, 0, idx] = 1
                        tail = search(idx_onehot, prob_current, hidden, n)

                        #later finds are always better than earlier ones, so the latest non-None result wins
                        if tail is not None:
                            best_prediction = [idx] + tail

                return best_prediction

            best_prediction = search(start_token_onehot.unsqueeze(0).unsqueeze(0), 0.0, hidden, 0)

            return torch.tensor(best_prediction if best_prediction is not None else [], dtype=torch.int)

    def _step_next_token(self, input:torch.Tensor, c:tuple=None):
        """Run one step of the recurrent layer and pick the most likely next token.

        Args:
            input (torch.Tensor): embedded input (1 x sequence_len x embedding_size)
            c: recurrent state from the previous step (None for a fresh state)

        Returns:
            (log-softmax of the last output, one-hot argmax token (1 x 1 x vocab_size), new recurrent state)
        """
        with torch.no_grad():
            x, c = self.lstm(input, c)
            x = x[:, [-1], :]
            x = self.fc(x)
            max_idx = torch.argmax(x, dim=2)
            one_hot = torch.zeros_like(x).scatter_(2, max_idx.unsqueeze(2), 1)

            return F.log_softmax(x, dim=2), one_hot, c

def make_cnn_resnet50(embedding_size:int):
    """Build an ImageNet-pretrained ResNet-50 image encoder.

    All pretrained weights are frozen. The final `fc` layer is replaced by a trainable head:
    BatchNorm -> Linear(embedding_size) -> BatchNorm.
    """
    backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    backbone.requires_grad_(False) #freeze pretrained weights; the head created below stays trainable

    num_features = backbone.fc.in_features #2048 for ResNet-50
    backbone.fc = nn.Sequential(
        nn.BatchNorm1d(num_features),
        nn.Linear(num_features, embedding_size),
        nn.BatchNorm1d(embedding_size),
    )

    return backbone

def make_cnn_densenet201(embedding_size:int):
    """Build an ImageNet-pretrained DenseNet-201 image encoder.

    All pretrained weights are frozen. The final `classifier` layer is replaced by a trainable head:
    BatchNorm -> Linear(embedding_size) -> BatchNorm.
    """
    backbone = densenet201(weights=DenseNet201_Weights.IMAGENET1K_V1)
    backbone.requires_grad_(False) #freeze pretrained weights; the head created below stays trainable

    num_features = backbone.classifier.in_features
    backbone.classifier = nn.Sequential(
        nn.BatchNorm1d(num_features),
        nn.Linear(num_features, embedding_size),
        nn.BatchNorm1d(embedding_size),
    )

    return backbone

def make_rnn_lstm(embedding_size:int, hidden_size:int, vocab_size:int, dropout:float=0, num_layers:int=2):
    """Build the caption decoder with its default LSTM recurrent layer."""
    decoder = RNN(
        embedding_size=embedding_size,
        hidden_size=hidden_size,
        vocab_size=vocab_size,
        dropout=dropout,
        num_layers=num_layers,
    )
    return decoder

def make_rnn_gru(embedding_size:int, hidden_size:int, vocab_size:int, dropout:float=0, num_layers:int=2):
    """Build the caption decoder with its LSTM swapped for a GRU of identical sizes.

    The attribute keeps the name `lstm`, so the decoder's methods work unchanged.
    """
    decoder = RNN(
        embedding_size=embedding_size,
        hidden_size=hidden_size,
        vocab_size=vocab_size,
        dropout=dropout,
        num_layers=num_layers,
    )
    decoder.lstm = nn.GRU(embedding_size, hidden_size, dropout=dropout, num_layers=num_layers, batch_first=True)
    return decoder

In [4]:
class LitModel(pl.LightningModule):
    """Lightning module combining the image encoder (CNN) with the caption decoder (RNN)."""

    def __init__(self, train_dataset:FlickrDataset, embedding_size:int=100, hidden_size:int=100, lstm_dropout:float=0, num_lstm_layers:int=2, cnn_type:str="resnet50", rnn_type:str="lstm", test_eval_step_pred_max_sequLen:int=30, test_step_pred_beamsize:int=3, alpha:float=0.0001):
        """
        Args:
            train_dataset (FlickrDataset): training dataset (only its tokenizer and augmentation flag are used)
            embedding_size (int): size of the image and token embeddings
            hidden_size (int): hidden size of the recurrent layer (cell state / hidden state)
            lstm_dropout (float): dropout probability between recurrent layers (if num_lstm_layers > 1)
            num_lstm_layers (int): number of recurrent layers
            cnn_type (str): image encoder, 'resnet50' or 'densenet201'
            rnn_type (str): decoder type, 'lstm' or 'gru'
            test_eval_step_pred_max_sequLen (int): maximum predicted caption length during testing and validation
            test_step_pred_beamsize (int): beam size of the beam search during testing
            alpha (float): weight decay passed to Adam

        Raises:
            ValueError: if `cnn_type` or `rnn_type` is unknown
        """

        super().__init__()

        self.test_eval_step_pred_max_sequLen = test_eval_step_pred_max_sequLen
        self.test_step_pred_beamsize = test_step_pred_beamsize
        self.tokenizer = train_dataset.tokenizer
        self.tokenizer_min_token_count = self.tokenizer.min_token_count
        self.data_augmentation_enabled = train_dataset.use_augmentation
        self.alpha = alpha

        if cnn_type == "resnet50":
            self.cnn = make_cnn_resnet50(embedding_size)
        elif cnn_type == "densenet201":
            self.cnn = make_cnn_densenet201(embedding_size)
        else:
            raise ValueError("unknown cnn_type")

        if rnn_type == "lstm":
            self.rnn = make_rnn_lstm(embedding_size, hidden_size, len(self.tokenizer), lstm_dropout, num_lstm_layers)
        elif rnn_type == "gru":
            self.rnn = make_rnn_gru(embedding_size, hidden_size, len(self.tokenizer), lstm_dropout, num_lstm_layers)
        else:
            raise ValueError("unknown rnn_type")

        #stores all init arguments (including train_dataset) so the model can be rebuilt from a checkpoint
        self.save_hyperparameters()

    def training_step(self, batch:tuple, batch_idx):
        """Lightning training step: teacher-forced NLL loss of the batch."""

        loss = self._get_teacherForce_loss(batch)

        self.log("train_loss", loss, on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch:tuple, batch_idx):
        """Lightning validation step: teacher-forced NLL loss, logged as `val_loss`.

        The greedy BLEU loss (`_get_greedy_bleu_loss`) is intentionally not logged anymore. It stopped improving
        after a couple of epochs and is expensive to compute.
        """
        with torch.no_grad():
            loss_tf = self._get_teacherForce_loss(batch)

        self.log("val_loss", loss_tf, on_step=True, on_epoch=True)

    def test_step(self, batch:tuple, batch_idx):
        """Lightning test step: mean sentence BLEU of greedy and beam-search captions against all reference captions."""
        batch_image, batch_captions = batch #images and, per image, the list of reference caption strings

        list_bleu_greedy = []
        list_bleu_beamSearch = []
        for image, image_captions in zip(batch_image, batch_captions):
            references = [self.tokenizer.tokenize_text(caption) for caption in image_captions]

            pred_greedy = self.sample_greedy(image, self.test_eval_step_pred_max_sequLen)
            pred_beamsearch = self.sample_beamSearch(image, self.test_eval_step_pred_max_sequLen, self.test_step_pred_beamsize)

            #the "<empty>" sentinel must become an empty hypothesis, otherwise BLEU would score its characters
            list_bleu_greedy.append(sentence_bleu(references, self._prediction_to_hypothesis(pred_greedy)))
            list_bleu_beamSearch.append(sentence_bleu(references, self._prediction_to_hypothesis(pred_beamsearch)))

        self.log("test_bleu_greedy", np.mean(list_bleu_greedy), on_step=True, on_epoch=True, batch_size=len(batch_image))
        self.log("test_bleu_beamSearch", np.mean(list_bleu_beamSearch), on_step=True, on_epoch=True, batch_size=len(batch_image))

    def _get_teacherForce_loss(self, batch:tuple):
        """Teacher-forced negative log-likelihood of a batch, ignoring <PAD> targets.

        Args:
            batch (tuple): (images, one-hot targets) as produced by the training/validation collate function
        Returns:
            mean NLL loss over all non-padding target tokens
        """

        batch_images:torch.Tensor
        batch_targets_onehot:torch.Tensor

        batch_images, batch_targets_onehot = batch

        image_embedding = self.cnn(batch_images)

        #the last target position is never used as input (nothing has to be predicted after it)
        output:torch.Tensor = self.rnn(image_embedding, batch_targets_onehot[:, :-1, :])

        output = output[:, 1:, :] #drop the output after the image; the first token (<SOS>) is not predicted
        batch_targets_onehot = batch_targets_onehot[:, 1:, :] #drop <SOS> from the targets

        #merge batch and sequence dimension
        batch_targets_onehot = batch_targets_onehot.flatten(end_dim=1)
        output = output.flatten(end_dim=1)

        loss = F.nll_loss(output, batch_targets_onehot.argmax(dim=1), ignore_index=int(self.tokenizer.tok_to_index["<PAD>"]))

        return loss

    def _get_greedy_bleu_loss(self, batch:tuple):
        """Mean BLEU loss (1 - sentence BLEU) of greedy captions against each sample's target caption.

        Args:
            batch (tuple): (images, one-hot targets) as produced by the training/validation collate function
        Returns:
            mean BLEU loss of the batch
        """

        batch_images, batch_targets_onehot = batch

        image_embedding:torch.Tensor = self.cnn(batch_images)

        start_token_onehot, stop_token_onehot = self._special_token_onehots(image_embedding.device)

        special_tokens = {'<SOS>', '<EOS>', '<PAD>'}
        bleu_losses = []
        for i in range(len(batch_images)):
            image:torch.Tensor = image_embedding[[i]] #keeps the batch dimension (1 x embedding_size)

            reference = [token for token in self.tokenizer.oneHot_sequence_to_tokens(batch_targets_onehot[i]) if token not in special_tokens]

            greedy_onehot = self.rnn.sample_greedy(image, start_token_onehot, stop_token_onehot, self.test_eval_step_pred_max_sequLen)
            #sample_greedy returns a 0-dim tensor when nothing was predicted -> empty hypothesis
            prediction = [] if greedy_onehot.dim() == 0 else self.tokenizer.oneHot_sequence_to_tokens(greedy_onehot)

            bleu_losses.append(1-sentence_bleu([reference], prediction))

        return np.mean(bleu_losses)

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), weight_decay=self.alpha)
        return optimizer

    def sample_greedy(self, image:torch.Tensor, max_num_words:int=10):
        """Predict a caption with greedy decoding.

        Args:
            image (torch.Tensor): single image tensor (C x H x W); moved to the model's device
            max_num_words (int): maximum number of predicted tokens

        Returns:
            list of string tokens, or the string "<empty>" if no token was predicted
        """

        was_training = self.training
        self.eval() #BatchNorm layers must use running statistics for a batch of one image
        try:
            with torch.no_grad():
                image_embedding:torch.Tensor = self.cnn(image.unsqueeze(0).to(self.device))

                start_token_onehot, stop_token_onehot = self._special_token_onehots(image_embedding.device)

                output = self.rnn.sample_greedy(image_embedding, start_token_onehot, stop_token_onehot, max_num_words)

                if len(output.size()) == 0: #0-dim tensor -> nothing predicted
                    return "<empty>"

                return self.tokenizer.oneHot_sequence_to_tokens(output.cpu())
        finally:
            self.train(was_training) #restore the caller's train/eval mode

    def sample_beamSearch(self, image:torch.Tensor, max_num_words:int=3, beam_size:int=2):
        """Predict a caption with the decoder's beam search.

        Args:
            image (torch.Tensor): single image tensor (C x H x W); moved to the model's device
            max_num_words (int): maximum number of predicted tokens
            beam_size (int): beam size

        Returns:
            list of string tokens, or the string "<empty>" if no token was predicted
        """

        was_training = self.training
        self.eval() #BatchNorm layers must use running statistics for a batch of one image
        try:
            with torch.no_grad():
                image_embedding:torch.Tensor = self.cnn(image.unsqueeze(0).to(self.device))

                start_token_onehot, stop_token_onehot = self._special_token_onehots(image_embedding.device)

                output = self.rnn.sample_beamSearch(image_embedding, start_token_onehot, stop_token_onehot, max_num_words, beam_size)

                if len(output) == 0:
                    return "<empty>"

                return self.tokenizer.numerical_to_tokens(output)
        finally:
            self.train(was_training) #restore the caller's train/eval mode

    def _special_token_onehots(self, device):
        """One-hot vectors (vocab_size) of the <SOS> and <EOS> tokens on `device`."""
        start_token_onehot = self.tokenizer.numerical_to_matrix(torch.tensor(self.tokenizer.tok_to_index["<SOS>"])).to(device)
        stop_token_onehot = self.tokenizer.numerical_to_matrix(torch.tensor(self.tokenizer.tok_to_index["<EOS>"])).to(device)
        return start_token_onehot, stop_token_onehot

    @staticmethod
    def _prediction_to_hypothesis(prediction):
        """Map the "<empty>" sentinel string of the sample_* methods to an empty token list for BLEU."""
        return [] if isinstance(prediction, str) else prediction

In [None]:
##load data
batchsize = 64
min_token_count = 3
train_augmentation_multiplier = 6
img_resize = 224
num_workers_train = 8

#dataset root containing the "Images" folder and "captions.txt"; override via the FLICKR_DATA_DIR environment variable
data_dir = Path(os.environ.get("FLICKR_DATA_DIR", r"C:\Users\tobia\archive"))

#NOTE(review): torchvision's ImageNet weights were trained with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225).
#The (0.5, 0.5, 0.5) normalization is kept so existing checkpoints stay consistent — TODO evaluate switching.
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((img_resize, img_resize), antialias=True), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_loader, validation_loader, test_loader, train_dataset, valid_dataset, test_dataset = get_loader(str(data_dir / "Images"), str(data_dir / "captions.txt"), transform, batchsize, min_token_count=min_token_count, train_augmentation_multiplier=train_augmentation_multiplier, num_workers_train=num_workers_train)

In [None]:
##load model from checkpoint
model_run_id = "model-if28cok5:best"
model_artifact = wandb.Api().artifact(f"{user}/{project}/{model_run_id}")
artifact_dir = model_artifact.download()

#the checkpoint contains all init arguments (see save_hyperparameters), so none have to be passed here
checkpoint_path = Path(artifact_dir) / "model.ckpt"
model = LitModel.load_from_checkpoint(checkpoint_path)

In [None]:
##train model in a single run

#create fresh model
embedding_size = 512
hidden_size = 1024
lstm_dropout = 0.5
lstm_num_layers = 3
cnn_type = "densenet201"
rnn_type = "lstm"
test_step_pred_max_sequLen = 30
test_step_pred_beamsize = 3
alpha = 0.0001

#keyword arguments make the mapping of notebook variables to constructor parameters explicit
model = LitModel(
    train_dataset,
    embedding_size=embedding_size,
    hidden_size=hidden_size,
    lstm_dropout=lstm_dropout,
    num_lstm_layers=lstm_num_layers,
    cnn_type=cnn_type,
    rnn_type=rnn_type,
    test_eval_step_pred_max_sequLen=test_step_pred_max_sequLen,
    test_step_pred_beamsize=test_step_pred_beamsize,
    alpha=alpha,
)

#start training
logger = WandbLogger(project=project, log_model="all")
logger.watch(model, log="all")

checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=2, verbose=False, mode="min")

trainer = pl.Trainer(devices=1, accelerator="gpu", logger=logger, callbacks=[early_stop_callback, checkpoint_callback], log_every_n_steps=50)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=validation_loader)
trainer.test(ckpt_path="best", dataloaders=test_loader)

#best_score is a 0-dim tensor -> log it as a plain float
best_val_loss = float(early_stop_callback.state_dict()["best_score"])

wandb.log({"best_val_loss":best_val_loss})

wandb.finish()

In [None]:
##train model in a sweep

def sweep_iteration():
    """Train and evaluate one model as a single W&B sweep run.

    Reads the hyperparameters chosen by the sweep agent from ``wandb.config``,
    trains with early stopping on ``val_loss``, tests the best checkpoint and logs
    ``best_val_loss`` (the metric the sweep config minimises).

    Uses the notebook-level ``train_dataset``, ``train_loader``, ``validation_loader``,
    ``test_loader`` and ``project``.
    """
    wandb.init()
    try:
        config = wandb.config

        #create fresh model
        embedding_size = config.embedding_size
        hidden_size = config.hidden_size
        lstm_dropout = config.lstm_dropout
        lstm_num_layers = config.lstm_num_layers
        cnn_type = config.cnn_type
        rnn_type = config.rnn_type
        test_step_pred_max_sequLen = 30
        test_step_pred_beamsize = 3
        alpha = config.alpha

        model = LitModel(train_dataset, embedding_size, hidden_size, lstm_dropout, lstm_num_layers, cnn_type, rnn_type, test_step_pred_max_sequLen, test_step_pred_beamsize, alpha)

        #start training
        logger = WandbLogger(project=project, log_model="all")
        logger.watch(model, log="all")

        checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
        early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=2, verbose=False, mode="min")

        trainer = pl.Trainer(devices=1, accelerator="gpu", logger=logger, callbacks=[early_stop_callback, checkpoint_callback], log_every_n_steps=50)
        trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=validation_loader)
        # evaluate the checkpoint with the lowest val_loss, not the last epoch
        trainer.test(ckpt_path="best", dataloaders=test_loader)

        # best validation loss tracked by the early stopper; it may be a (GPU) tensor,
        # so convert it to a plain float before logging
        best_val_loss = float(early_stop_callback.state_dict()["best_score"])
        wandb.log({"best_val_loss": best_val_loss})
    except BaseException:
        # close this run as failed so the next sweep run starts from a clean state
        wandb.finish(exit_code=1)
        raise

    wandb.finish()

# search space: every combination is tried (grid) -> 3 * 3 * 2 = 18 runs
sweep_parameters = {
    "embedding_size": {"values": [16, 128, 512]},
    "hidden_size": {"values": [16, 128, 512]},
    "lstm_dropout": {"values": [0.5]},
    "lstm_num_layers": {"values": [3]},
    "cnn_type": {"values": ["resnet50", "densenet201"]},
    "rnn_type": {"values": ["lstm"]},
    "alpha": {"values": [0.0001]},
}

sweep_config = {
    "method": "grid",
    "name": "test_sweep",
    # sweep_iteration logs this metric at the end of every run
    "metric": {"goal": "minimize", "name": "best_val_loss"},
    "parameters": sweep_parameters,
}

sweep_id = wandb.sweep(sweep_config, project=project)
wandb.agent(sweep_id, sweep_iteration)

In [None]:
#show multiple images

def imshow(img):
    """Display a normalised CHW image tensor with matplotlib.

    Assumes the image was normalised with mean = std = 0.5 per channel (as in the
    ``transform`` of the data-loading cell); this maps it back before plotting.
    """
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()

ds = test_dataset #choose dataset

max_num_words = 30
num_images = min(100, len(ds))  # never index past the end of a small dataset
model_device = model.device      # run inference wherever the model lives
for i in range(num_images):
    image, captions = ds[i]

    imshow(image)
    print("Reference captions:\n", "\n".join([str(ds.tokenizer.tokenize_text(caption)) for caption in captions]))

    model_input = image.to(model_device)
    print("greedy caption: ", model.sample_greedy(model_input, max_num_words=max_num_words))
    print("beamsearch caption: ", model.sample_beamSearch(model_input, max_num_words=max_num_words, beam_size=5))

In [None]:
#show single image

def imshow(img):
    """Display a normalised CHW image tensor with matplotlib.

    Assumes the image was normalised with mean = std = 0.5 per channel (as in the
    ``transform`` of the data-loading cell); this maps it back before plotting.
    """
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()

ds = test_dataset #choose dataset
image_index = 18  #which sample of the dataset to caption

image, captions = ds[image_index]

imshow(image)
print("Reference captions:\n", "\n".join([str(ds.tokenizer.tokenize_text(caption)) for caption in captions]))

max_num_words = 30
model_input = image.to(model.device)  # run inference wherever the model lives
print("greedy caption: ", model.sample_greedy(model_input, max_num_words=max_num_words))
print("beamsearch caption: ", model.sample_beamSearch(model_input, max_num_words=max_num_words, beam_size=5))

In [None]:
#custom image

def imshow(img):
    """Display a normalised CHW image tensor with matplotlib.

    Assumes the image was normalised with mean = std = 0.5 per channel (as in
    ``transform``); this maps it back before plotting.
    """
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()

# ask the user for an image file with a hidden Tk root window
root = tkinter.Tk()
root.withdraw()
try:
    file_path = tkinter.filedialog.askopenfilename()
finally:
    root.destroy()

if not file_path:
    # the dialog was cancelled, so there is nothing to caption
    print("No image selected.")
else:
    # convert to RGB: `transform` normalises exactly 3 channels, so RGBA (e.g. PNG)
    # or greyscale files would otherwise fail; `with` closes the file handle
    with Image.open(file_path) as pil_image:
        image = transform(pil_image.convert("RGB"))

    imshow(image)

    max_num_words = 30
    model_input = image.to(model.device)  # run inference wherever the model lives
    print("greedy caption: ", model.sample_greedy(model_input, max_num_words=max_num_words))
    print("beamsearch caption: ", model.sample_beamSearch(model_input, max_num_words=max_num_words, beam_size=5))