## Setup, Installations, and Imports

In [None]:
#install packages as necessary
!pip -q install accelerate datasets transformers evaluate jiwer

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
#import our packages
import os
import torch
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
from datasets import load_metric
from transformers import AutoImageProcessor, AutoModelForImageClassification, Trainer, TrainingArguments, BlipProcessor, BlipForConditionalGeneration
import matplotlib.pyplot as plt
import accelerate
import numpy as np
import pickle
import random
import tqdm
import requests
import evaluate
from PIL import Image
import copy
from rich.progress import track

In [None]:
# Mount Google Drive so that files saved there persist across Colab sessions
# (the data then does not have to be reprocessed every time a session restarts).
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


## Poliphonic Data Extraction and Preprocessing

In [None]:
'''
Note: 
- This section is just meant to extract the data from the original tgz file, process it into a Poliphonic dataset object, and save it accordingly (such that it can persist beyond individual Google Colab sessions)
If all of the above has already been done, skip this section.
- This section was originally a part of code for another model we were experimenting with. Not all of it is necessary for the finetuned blip model but since
the datasets were already created anyway for prior experiments, we leave it and leverage a piece of it for the finetuned blip model
'''

In [None]:
#this specific cell block is based on the original paper that created the raw dataset we're using.
#it's been modified/adapted for our use case but credit goes to that paper and its authors. It provides some processing of bekrn files
#that has been helpful for our group, since we are not familiar with the notation or the predecessor notation it was based on

def load_data(partition_file, resize_ratio=1., use_raw_krn=False, load_distorted=False, extension=".bekrn"):
    """Load score images and their tokenised transcriptions from a partition file.

    Each line of ``partition_file`` is treated as the path of a transcription
    file. Lines that do not point to an existing file, files that cannot be
    read as text, and samples without a readable image are skipped.

    Tokenisation: spaces become ``<s>``, tabs become ``<t>``, the ``·``
    character becomes a token separator, and ``<b>`` is appended after every
    line that yields more than one token.

    Args:
        partition_file: Path of a text file with one transcription path per line.
        resize_ratio: Factor applied to both image dimensions before rotation.
        use_raw_krn: Unused; kept so existing callers keep working.
        load_distorted: If True, read ``<stem>_distorted.jpg`` instead of
            ``<stem>.jpg``, rescale it to a height of 256 and keep the sample
            only if ``(height // 8) * (width // 16)`` exceeds its token count.
        extension: Extension substituted for ``.bekrn`` in every listed path.

    Returns:
        A tuple ``(X, Y)`` of equal-length lists: the grayscale images (resized,
        then rotated 90 degrees clockwise) and the matching flat token lists.
    """
    # cv2 is needed below but is not among the notebook's imports, so it is
    # brought into scope here to keep the function usable on its own.
    import cv2

    X = []
    Y = []
    with open(partition_file) as partfile:
        part_lines = partfile.read().split("\n")

    for file_path in track(part_lines, description="Loading..."):
        if extension != ".bekrn":
            file_path = file_path.replace(".bekrn", extension)
        if not os.path.isfile(file_path):
            continue

        try:
            with open(file_path) as krnfile:
                krn = krnfile.read()
        except (OSError, UnicodeDecodeError):
            # Unreadable or non-text files are skipped instead of aborting the load.
            continue

        krn = krn.replace(" ", " <s> ").replace("·", " ")
        tokens = []
        for line in krn.split("\n"):
            parts = line.replace("\t", " <t> ").split(" ")
            if len(parts) > 1:
                tokens.extend(parts)
                tokens.append("<b>")

        # The image stem is everything before the first '.' of the path. This is
        # the same rule blip_music_dataset uses, which keeps the two in step.
        stem = file_path.split('.')[0]
        if not os.path.exists(f"{stem}.jpg"):
            continue

        if load_distorted:
            img = cv2.imread(f"{stem}_distorted.jpg", 0)
            if img is None:
                # Missing or undecodable distorted image.
                continue
            height = 256
            width = int(float(height * img.shape[1]) / img.shape[0])
            img = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)
            if (height // 8) * (width // 16) <= len(tokens):
                continue
        else:
            img = cv2.imread(f"{stem}.jpg", 0)
            if img is None:
                # The file exists but could not be decoded as an image.
                continue

        width = int(np.ceil(img.shape[1] * resize_ratio))
        height = int(np.ceil(img.shape[0] * resize_ratio))
        img = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)
        img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        X.append(img)
        Y.append(tokens)

    print(len(X))
    print(len(Y))
    return X, Y


class PoliphonicDataset(Dataset):
    """Dataset of score images paired with vocabulary-encoded token sequences.

    Images and token lists are produced by ``load_data``. The vocabulary has to
    be supplied through ``set_dictionaries`` before items are fetched.
    """

    def __init__(self, partition_file) -> None:
        images, sequences = load_data(partition_file)
        self.x = images
        self.y = sequences
        self.tensorTransform = transforms.ToTensor()

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.x)

    def __getitem__(self, index):
        """Return ``(image, gt, cell_count, gt_length)`` for the sample at ``index``."""
        image = self.tensorTransform(self.x[index])
        token_ids = [self.w2i[token] for token in self.y[index]]
        gt = torch.from_numpy(np.asarray(token_ids))
        cell_count = (image.shape[2] // 8) * (image.shape[1] // 16)

        return image, gt, cell_count, len(gt)

    def get_max_hw(self):
        """Return the largest first and second image dimensions as ``(height, width)``."""
        widths = [img.shape[1] for img in self.x]
        heights = [img.shape[0] for img in self.x]

        return np.max(heights), np.max(widths)

    def get_max_seqlen(self):
        """Return the length of the longest token sequence."""
        lengths = [len(seq) for seq in self.y]
        return np.max(lengths)

    def vocab_size(self):
        """Return the number of entries in the word-to-index dictionary."""
        return len(self.w2i)

    def get_gt(self):
        """Return the list of token sequences."""
        return self.y

    def set_dictionaries(self, w2i, i2w):
        """Store both vocabulary mappings and record the index of ``'<pad>'``."""
        self.w2i, self.i2w = w2i, i2w
        self.padding_token = w2i['<pad>']

    def get_dictionaries(self):
        """Return ``(w2i, i2w)``."""
        return self.w2i, self.i2w

    def get_i2w(self):
        """Return the index-to-word dictionary."""
        return self.i2w

In [None]:
#unzip the tgz raw data file.
#note: change the below string paths as appropriate.

!tar -xzf "grandstaff.tgz" -C "grandstaff_dataset/"

tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.lastuseddate#PS'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.lastuseddate#PS'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.lastuseddate#PS'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.lastuseddate#PS'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.lastuseddate#PS'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
tar: Ignoring unknown extende

In [None]:
# Partition files that list the samples belonging to each split.
train_path = "drive/MyDrive/train_shuffled.txt"
val_path = "drive/MyDrive/val_shuffled.txt"
test_path = "drive/MyDrive/test_shuffled.txt"

# Build one PoliphonicDataset per split, in train / val / test order.
train_dataset, val_dataset, test_dataset = (
    PoliphonicDataset(partition_file=split_path)
    for split_path in (train_path, val_path, test_path)
)


In [None]:
# Sanity check: display the number of samples loaded for each split
# (note the order: train, test, val).
len(train_dataset), len(test_dataset), len(val_dataset)

(32330, 10776, 10776)

## Create the train, test, and val files

In [None]:
'''
Note: This section only decides which samples go into the train, val, and test datasets. If that has already been determined by some other method,
skip this section.
'''

In [None]:
# Collect the path of every .bekrn transcription in the extracted dataset
# (only the bekrn files are of interest here, not the plain krn files).
sample_bekrn = []

for root, dirs, files in os.walk("grandstaff_dataset/"):
    # os.walk yields entries in arbitrary filesystem order. Sorting in place
    # fixes the traversal order, so the seeded shuffle applied to this list
    # later gives the same split on every machine.
    dirs.sort()
    for name in sorted(files):
        # Hidden entries (names starting with '.') are skipped: paths of that
        # form were the ones for which no matching .jpg could be found.
        if name.endswith('.bekrn') and not name.startswith('.'):
            sample_bekrn.append(os.path.join(root, name))

# One summary line instead of printing every path (that flooded the cell output).
print(f"Found {len(sample_bekrn)} bekrn files")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
grandstaff_dataset/chopin/mazurkas/mazurka33-2/min3_up_m-16-20.bekrn
grandstaff_dataset/chopin/mazurkas/mazurka33-2/maj2_down_m-8-12.bekrn
grandstaff_dataset/chopin/mazurkas/mazurka33-2/min3_down_m-84-87.bekrn
grandstaff_dataset/chopin/mazurkas/mazurka33-2/min3_down_m-21-24.bekrn
grandstaff_dataset/chopin/mazurkas/mazurka33-2/maj2_up_m-18-21.bekrn
grandstaff_dataset/chopin/mazurkas/mazurka33-2/maj3_down_m-85-90.bekrn
grandstaff_dataset/chopin/mazurkas/mazurka33-2/min3_down_m-126-129.bekrn
grandstaff_dataset/chopin/mazurkas/mazurka33-2/maj2_down_m-116-120.bekrn
grandstaff_dataset/chopin/mazurkas/mazurka33-2/min3_down_m-96-99.bekrn
grandstaff_dataset/chopin/mazurkas/mazurka33-2/maj3_up_m-128-132.bekrn
grandstaff_dataset/chopin/mazurkas/mazurka33-2/maj2_up_m-54-57.bekrn
grandstaff_dataset/chopin/mazurkas/mazurka33-2/min3_down_m-30-33.bekrn
grandstaff_dataset/chopin/mazurkas/mazurka33-2/original_m-128-132.bekrn
grandstaff_dat

In [None]:
# Display the total number of bekrn sample paths collected above.
len(sample_bekrn)

In [None]:
# Positions in the sample list where the train portion (first 60%) and the
# validation portion (next 20%) end; the remaining 20% is the test portion.
# Both values are floats and are truncated to ints when the list is sliced.
total_samples = len(sample_bekrn)
train_pt = total_samples * 0.6
val_pt = total_samples * 0.8
train_pt, val_pt

(32333.399999999998, 43111.200000000004)

In [None]:
# Shuffle, split, and confirm the resulting sizes.

# Shuffling first avoids a split that is skewed towards particular composers
# (e.g. Beethoven snippets may differ from, say, Scarlatti snippets).
random.seed(0)  # fixed seed so the shuffle is repeatable
random.shuffle(sample_bekrn)

# Cut the shuffled list at the two split points.
train_end = int(train_pt)
val_end = int(val_pt)
train = sample_bekrn[:train_end]
val = sample_bekrn[train_end:val_end]
test = sample_bekrn[val_end:]

# Display the size of each split.
len(train), len(val), len(test)

(32333, 10778, 10778)

In [None]:
#save the different dataset records to text files
def write_testset_files(data, file_name):
  '''
  Write every sample in data to the file named by file_name, one sample per line.
  An existing file is overwritten.
  '''
  with open(str(file_name), 'w') as out_file:
      out_file.writelines(f"{samp}\n" for samp in data)

In [None]:
# Persist the three splits (train, then test, then val) so they can be reused later.
for split_samples, split_file in (
    (train, 'drive/MyDrive/train_shuffled.txt'),
    (test, 'drive/MyDrive/test_shuffled.txt'),
    (val, 'drive/MyDrive/val_shuffled.txt'),
):
  write_testset_files(split_samples, split_file)

## How do multimodals do out of the box?

In [None]:
### Let's test a popular vision-language model from Salesforce called BLIP.
### It seems to be used mostly for tasks like image captioning, but maybe we can train it to treat the sequence of musical notes as the 'caption'?

In [None]:
# Adapted from the Hugging Face documentation example for BLIP.

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell also runs on a CPU-only runtime.
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

img_url = 'grandstaff_dataset/beethoven/piano-sonatas/sonata31-2/maj2_down_m-1-5.jpg' #a random jpg image from our data (a local path, not a URL)
raw_image = Image.open(img_url).convert('RGB')

# conditional image captioning
text = "a list of music symbols starting with the symbol for " #i tested a number of different leading texts here but none worked well
# The model is loaded without a torch_dtype, so the inputs keep the processor's
# default dtype instead of being cast to float16.
inputs = processor(raw_image, text, return_tensors="pt").to(device)

out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))

# unconditional image captioning
# The image is sent twice as a sanity check that batched inputs work; only the
# first caption of the batch is printed.
inputs = processor([raw_image, raw_image], return_tensors="pt").to(device)

out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


a list of music symbols starting with the symbol for the letter f
a black and white image of a musical score with notes


In [None]:
## the above is not really unreasonable but it's not useful nor specific enough. let's try finetuning to see if the model can do better?

## Finetuning BLIP

In [None]:
torch.cuda.empty_cache() # frees cached GPU memory that PyTorch holds but is not using; run this when CUDA memory is running low

In [None]:
#ok, let's try finetuning and see if we can get better results that way
class blip_music_dataset(torch.utils.data.Dataset):
    '''
    Music dataset whose items can be used directly as input for the BLIP model.

    Each line of the data file is the path of a transcription file; the image
    path is derived from it by replacing everything from the first '.' onwards
    with '.jpg'. Ground-truth labels are looked up by position, so
    poliphonic_formatted_gt must hold one entry per path that has an image,
    in the same order.
    '''
    def __init__(self, data_path, processor, poliphonic_formatted_gt, transforms=None):
        '''
        Read the sample paths, drop those without an image, and check the labels.

        Args:
          data_path (str): the path to the data for which we want to create a Dataset object
          processor (obj): the img-text processor for inputs that are intended for blip
          poliphonic_formatted_gt (list): the list of processed ground truth text labels (processed by the Poliphonic dataset beforehand)
          transforms (callable): optional image transformation applied before the image is sent to the processor and blip
        '''
        import warnings

        #set up the inputs
        self.data_path = data_path
        self.processor = processor
        self.gt = poliphonic_formatted_gt
        self.transforms = transforms or None

        #read in the sample file paths in the data file (lines are stored as read, including the newline)
        with open(self.data_path, 'r') as f:
            self.sample_paths = list(f)

        #run a sweep to make sure there are no paths that could cause errors later
        self.check_exists()

        #labels are matched to paths by index, so a length mismatch means misaligned labels
        if len(self.gt) != len(self.sample_paths):
            warnings.warn(
                f"blip_music_dataset: {len(self.gt)} ground-truth entries but "
                f"{len(self.sample_paths)} samples with images in {self.data_path}; "
                "labels may not line up with images."
            )

    def check_exists(self):
        '''
        With this particular grandstaff_dataset source, there appear to be a few missing image files (i.e. bekrn exists but no jpg.).
        This function removes those samples so they do not cause errors when __getitem__ is called.
        It filters the paths already held in memory, so calling it more than once is safe.
        '''
        kept = []
        for line in self.sample_paths:
            new = line.split('.')[0]+'.jpg'
            if os.path.exists(new):
                kept.append(line)
            else:
                print(f"Could not locate image file for: {new}")
        self.sample_paths = kept

    def __len__(self):
        '''
        Get the number of samples in the dataset.

        Returns:
          int: the length/size of the dataset
        '''
        return len(self.sample_paths)

    def __getitem__(self, idx):
        '''
        Retrieve processed sample from dataset

        Args:
          idx (int): the index of the sample within the dataset
        Returns:
          encoding (dict): the processor output for this sample with the batch dimension removed
        '''

        #get the image pixels
        img_path = self.sample_paths[idx].split('.')[0] + '.jpg'
        raw_image = Image.open(img_path).convert('RGB')
        if self.transforms:
            raw_image = self.transforms(raw_image)

        #labels: the token list for this sample concatenated into one string
        labels = "".join(self.gt[idx])

        #send it to the blip processor (due to model + gpu constraints, truncation is allowed)
        encoding = self.processor(images=raw_image, text=labels, max_length=None, padding='max_length', truncation=True, return_tensors="pt")

        #remove only the leading batch dimension, so other size-1 dimensions are left intact
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}

        return encoding

In [None]:
# Build the BLIP datasets and dataloaders. batch_size stays at 2 because larger
# batches ran out of CUDA memory later on in the training loop.
train_dataset_blip = blip_music_dataset('drive/MyDrive/train_shuffled.txt', processor, train_dataset.get_gt(), transforms=None)
train_dataloader_blip = DataLoader(train_dataset_blip, shuffle=True, batch_size=2)

# Validation and test data are only evaluated, never trained on, so they are
# read in a fixed order. Running metrics at a given step then cover the same
# samples in every epoch and can be compared with each other.
val_dataset_blip = blip_music_dataset('drive/MyDrive/val_shuffled.txt', processor, val_dataset.get_gt(), transforms=None)
val_dataloader_blip = DataLoader(val_dataset_blip, shuffle=False, batch_size=2)

test_dataset_blip = blip_music_dataset('drive/MyDrive/test_shuffled.txt', processor, test_dataset.get_gt(), transforms=None)
test_dataloader_blip = DataLoader(test_dataset_blip, shuffle=False, batch_size=2)


Could not locate image file for: grandstaff_dataset/mozart/piano-sonatas/sonata15-2/.jpg
Could not locate image file for: grandstaff_dataset/mozart/piano-sonatas/sonata03-1/.jpg
Could not locate image file for: grandstaff_dataset/hummel/preludes/prelude67-01/.jpg
Could not locate image file for: grandstaff_dataset/hummel/preludes/prelude67-01/.jpg
Could not locate image file for: grandstaff_dataset/chopin/mazurkas/mazurka30-4/.jpg
Could not locate image file for: grandstaff_dataset/mozart/piano-sonatas/sonata06-3e/.jpg
Could not locate image file for: grandstaff_dataset/chopin/mazurkas/mazurka41-2/.jpg


In [None]:
# Note: the messages above show a '/' directly before '.jpg', i.e. the derived image name is empty. Even after removing
# that '/', no associated image file could be found, so these paths are simply dropped
# (there are only 7 bekrn paths that caused errors).
# As an example:
os.path.exists("grandstaff_dataset/chopin/mazurkas/mazurka41-2.jpg")

False

In [None]:
# Fetch one batch from the training dataloader and display its input_ids, to confirm the
# samples have the expected format (especially if the data was just read in this session).
# NOTE(review): `test` is also the name of the test-split path list created when the data
# was split; this assignment overwrites it.
test = next(iter(train_dataloader_blip))['input_ids']
test

tensor([[ 101, 1008, 1008,  ..., 1028, 1012,  102],
        [ 101, 1008, 1008,  ..., 2497, 1026,  102]])

Now let's load in the metric we want to use later: Word Error Rate.

This (roughly) measures the number of "errors" -- differences between the prediction and the ground-truth reference, such as deletions, substitutions, and insertions --
divided by the number of words in the reference. The higher the WER value, the worse the match; WER = 0 indicates a perfect generation/prediction. Because insertions count as errors, WER can exceed 1.

We picked this metric over others for many reasons. It is sensitive to exact matches of words as well as the order of the words (which rouge-n, another popular image-captioning metric, does not always appear to be), and allows for more reasonable flexibility than other metrics like f1/precision/recall for our use case.

In [None]:
# Load the word-error-rate metric through the `evaluate` library.
wer = evaluate.load('wer')

Note: to avoid unfairly penalizing the model during WER eval, we put in a section below to truncate the ground truth label to 500 tokens if need be
for the purposes of wer evaluation only. 

In [None]:
# Fine-tune BLIP for two epochs, alternating a training pass and a validation
# pass, while tracking a running word error rate (WER) and checkpointing to Drive.

#setting up our optimizer with 5e-5 learning rate
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

#prep the datasets that will be used/seen during the training loop
# NOTE(review): `datasets` reuses the name imported from torchvision at the top
# of the notebook; the name is kept so code that refers to it keeps working.
name_datasets = ['train', 'val']
datasets = [train_dataloader_blip, val_dataloader_blip]

#set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Samples are padded to max_length by the dataset. Padding positions in the
# labels are replaced by -100, the index ignored by PyTorch's cross-entropy
# loss, so the loss only covers real tokens.
pad_token_id = processor.tokenizer.pad_token_id

#initialize the metrics
# Lower WER is better and 0 is a perfect match. WER can exceed 1 when there are
# many insertions; starting at 1.0 means a checkpoint is only written once the
# running WER drops below 1.0.
best_wer = 1.0

#let's run a simple training loop (not putting this in a function since I'm not really running this for a bunch of different datasets)
for epoch in range(2): #keeping the epoch numbers small because performance seemed far worse when we previously trained beyond a few epochs

  print("Epoch:", epoch)

  for name, data in zip(name_datasets, datasets):

    is_train = name == 'train'
    if is_train:
      model.train() #set model to training mode
    else:
      model.eval() #set model to eval mode

    #running sum of per-batch wer scores for this pass
    rwer = 0

    for idx, batch in enumerate(tqdm.tqdm(data)):

      #get the specific pieces of information we need from each sample
      input_ids = batch.pop("input_ids").to(device)
      pixel_values = batch.pop("pixel_values").to(device)
      attention_mask = batch.pop("attention_mask").to(device)

      #labels are the input ids with padding masked out of the loss
      labels = input_ids.masked_fill(input_ids == pad_token_id, -100)

      optimizer.zero_grad() #zero out the gradient

      with torch.set_grad_enabled(is_train):

        #send the samples to the model
        outputs = model(input_ids=input_ids,
                        pixel_values=pixel_values,
                        labels=labels,
                        attention_mask=attention_mask)

        #get and print loss
        loss = outputs.loss
        print(f"{name} loss:", loss.item())

        #backprop
        if is_train:
          loss.backward()
          optimizer.step()

      #check how the current model at this stage of training fares on wer metrics.
      #generation runs in eval mode so dropout does not add noise to the metric.
      model.eval()
      generated_ids = model.generate(pixel_values=pixel_values, max_new_tokens=500) #received warnings about degraded performance when I set max_new_tokens beyond 500, so keeping 500 here.
      if is_train:
        model.train()
      generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
      gt_caption = processor.batch_decode([i[:500] for i in input_ids], skip_special_tokens=True) #truncating tokens to 500 so model is not unfairly penalized for 'deleted'/missing words when/if ground truth tokens exceed 500

      if idx % 100 == 0:
        print(f"This was the generated caption: {generated_caption}")
        print(f"This was the true caption: {gt_caption}")

      wer_score = wer.compute(predictions=generated_caption, references=gt_caption)
      rwer += wer_score

      #running mean of the per-batch wer over this pass so far
      agg_wer = rwer/(idx+1)

      #log the running wer every 100 steps
      if idx % 100 == 0:
        print(f"Step {idx+1} -- wer: {agg_wer}")

      #checkpoint every 1000 steps (instead of once per epoch) whenever the running wer improves:
      #colab sessions can time out, the training set is large, and the model changes noticeably within a few hundred steps.
      #the comparison is made once so that both files below can be written; previously the validation-only
      #save re-checked agg_wer < best_wer after best_wer had just been updated, so it could never run.
      if idx % 1000 == 0 and agg_wer < best_wer: #the lower the score, the better
        best_wer = agg_wer
        best_model_wts_wer = copy.deepcopy(model.state_dict())
        torch.save(best_model_wts_wer, f'drive/MyDrive/fine_tuned_blip_wer_{name}_idx_{idx}')

        #the best weights seen during a validation pass are also kept under a fixed name
        if name == 'val':
          torch.save(best_model_wts_wer, 'drive/MyDrive/fine_tuned_blip_wer_WEIGHTS_ONLY')

print(f"BEST WER: {best_wer}")

Epoch: 0


  0%|          | 0/16165 [00:00<?, ?it/s]

train loss: 6.785060882568359


  0%|          | 1/16165 [00:01<8:13:39,  1.83s/it]

This was the generated caption: ['there are many birds sitting on a wire with music notes', 'arafed wire with birds sitting on it']
This was the true caption: ['* * ekern _ 1. 0 < t > * * ekern _ 1. 0 < b > * cleff4 < t > * clefg2 < b > * k [ b - ] < t > * k [ b - ] < b > * m3 / 4 < t > * m3 / 4 < b > = - < t > = - < b > 4cc < t > 8. ddl < b >. < t > 16ccjk < b > 4c < s > 4e < s > 4b - < t > 4b - < b > 4c < s > 4e < s > 4b - < t > 8al < b >. < t > 8gj < b > = < t > = < b > 4ff < t > 8fl < b >. < t > 8gj < b > 4c < s > 4f < s > 4a < t > 4a < b > 4c < s > 4f < s > 4a < t > 8al < b >. < t > 8b - j < b > = < t > = < b > 4cc < t > 8ddl < b >. < t > 16r < b >. < t > 16ccjk < b > 4c < s > 4e < s > 4b - < t > 4b - < b > 4c < s > 4e < s > 4b - < t > 8al < b >. < t > 8gj < b > = < t > = < b > * ^ < t > * < b > 4ff < < t > 4r < t > 8fl < b >. < t >. < t > 8gj < b > 4f < s > 4a < t > 8cl < t > 2a < b >. < t > 8c # j < t >. < b > 4r < t > 4d < s > 4f < s > 4a < t >. < b > * v < t > * v < t > * < b 

  0%|          | 2/16165 [00:03<8:11:10,  1.82s/it]

train loss: 6.565906524658203


  0%|          | 3/16165 [00:05<8:05:50,  1.80s/it]

train loss: 5.567357063293457


  0%|          | 4/16165 [00:07<8:09:52,  1.82s/it]

train loss: 6.604998588562012


  0%|          | 5/16165 [00:09<8:13:59,  1.83s/it]

train loss: 5.291231155395508


  0%|          | 6/16165 [00:10<8:12:16,  1.83s/it]

train loss: 5.168936729431152


  0%|          | 7/16165 [00:12<8:13:14,  1.83s/it]

train loss: 5.511659622192383


  0%|          | 8/16165 [00:14<8:13:35,  1.83s/it]

train loss: 4.920980453491211


  0%|          | 9/16165 [00:16<8:12:47,  1.83s/it]

train loss: 4.43996000289917


  0%|          | 10/16165 [00:18<8:12:35,  1.83s/it]

train loss: 4.05305290222168


  0%|          | 11/16165 [00:20<8:13:11,  1.83s/it]

train loss: 4.8976874351501465


  0%|          | 12/16165 [00:21<8:12:05,  1.83s/it]

train loss: 4.694485664367676


  0%|          | 13/16165 [00:23<8:11:04,  1.82s/it]

train loss: 4.11912202835083


  0%|          | 14/16165 [00:25<8:11:38,  1.83s/it]

train loss: 4.386322021484375


  0%|          | 15/16165 [00:27<8:14:27,  1.84s/it]

train loss: 7.187243461608887


  0%|          | 16/16165 [00:29<8:17:54,  1.85s/it]

train loss: 5.810055732727051


  0%|          | 17/16165 [00:31<8:20:42,  1.86s/it]

train loss: 3.824472188949585


  0%|          | 18/16165 [00:32<8:13:19,  1.83s/it]

train loss: 3.294266939163208


  0%|          | 19/16165 [00:34<8:08:21,  1.81s/it]

train loss: 3.881500244140625


  0%|          | 20/16165 [00:36<8:08:58,  1.82s/it]

train loss: 4.423997402191162


  0%|          | 21/16165 [00:38<8:09:48,  1.82s/it]

train loss: 3.69490909576416


  0%|          | 22/16165 [00:40<8:15:40,  1.84s/it]

train loss: 4.680509090423584


  0%|          | 23/16165 [00:42<8:17:52,  1.85s/it]

train loss: 3.910386562347412


  0%|          | 24/16165 [00:43<8:15:00,  1.84s/it]

train loss: 3.2180047035217285


  0%|          | 25/16165 [00:45<8:17:37,  1.85s/it]

train loss: 3.052232265472412


  0%|          | 26/16165 [00:47<8:20:34,  1.86s/it]

train loss: 3.3935508728027344


  0%|          | 27/16165 [00:49<8:15:25,  1.84s/it]

train loss: 3.4846038818359375


  0%|          | 28/16165 [00:51<8:11:23,  1.83s/it]

train loss: 3.227027654647827


  0%|          | 29/16165 [00:53<8:14:21,  1.84s/it]

train loss: 2.836740016937256


  0%|          | 30/16165 [00:55<8:16:32,  1.85s/it]

train loss: 2.9687929153442383


  0%|          | 31/16165 [00:56<8:13:33,  1.84s/it]

train loss: 3.7637314796447754


  0%|          | 32/16165 [00:58<8:11:10,  1.83s/it]

train loss: 2.8456239700317383


  0%|          | 33/16165 [01:00<8:12:53,  1.83s/it]

train loss: 3.324714183807373


  0%|          | 34/16165 [01:12<21:15:56,  4.75s/it]

train loss: 2.5086212158203125


  0%|          | 35/16165 [01:23<30:11:30,  6.74s/it]

train loss: 2.3986659049987793


  0%|          | 36/16165 [01:34<36:33:03,  8.16s/it]

train loss: 2.5228724479675293


  0%|          | 37/16165 [01:46<41:05:21,  9.17s/it]

train loss: 3.9952316284179688


  0%|          | 38/16165 [01:48<31:18:03,  6.99s/it]

train loss: 2.686800956726074


  0%|          | 39/16165 [01:50<24:27:45,  5.46s/it]

train loss: 2.8970258235931396


  0%|          | 40/16165 [01:52<19:37:16,  4.38s/it]

train loss: 2.2818715572357178


  0%|          | 41/16165 [02:03<29:15:00,  6.53s/it]

train loss: 2.1446077823638916


  0%|          | 42/16165 [02:15<35:57:34,  8.03s/it]

train loss: 2.179285764694214


  0%|          | 43/16165 [02:26<40:37:19,  9.07s/it]

train loss: 2.423391819000244


  0%|          | 44/16165 [02:38<43:50:16,  9.79s/it]

train loss: 1.9432594776153564


  0%|          | 45/16165 [02:49<46:04:26, 10.29s/it]

train loss: 3.2182183265686035


  0%|          | 46/16165 [02:51<34:42:12,  7.75s/it]

train loss: 2.52327299118042


  0%|          | 47/16165 [02:53<26:45:38,  5.98s/it]

train loss: 2.8417258262634277


  0%|          | 48/16165 [02:55<21:07:25,  4.72s/it]

train loss: 1.890350580215454


  0%|          | 49/16165 [02:56<17:10:45,  3.84s/it]

train loss: 2.3629534244537354


  0%|          | 50/16165 [02:58<14:30:39,  3.24s/it]

train loss: 2.0641233921051025


  0%|          | 51/16165 [03:00<12:37:40,  2.82s/it]

train loss: 3.091564416885376


  0%|          | 52/16165 [03:11<24:14:03,  5.41s/it]

train loss: 2.1313843727111816


  0%|          | 53/16165 [03:13<19:20:53,  4.32s/it]

train loss: 2.2723867893218994


  0%|          | 54/16165 [03:15<15:58:04,  3.57s/it]

train loss: 2.035977363586426


  0%|          | 55/16165 [03:17<13:40:17,  3.06s/it]

train loss: 1.6723631620407104


  0%|          | 56/16165 [03:19<12:03:59,  2.70s/it]

train loss: 1.8318439722061157


  0%|          | 57/16165 [03:21<10:52:06,  2.43s/it]

train loss: 2.0938332080841064


  0%|          | 58/16165 [03:22<10:02:28,  2.24s/it]

train loss: 3.266624689102173


  0%|          | 59/16165 [03:24<9:27:01,  2.11s/it] 

train loss: 1.7110854387283325


  0%|          | 60/16165 [03:26<9:07:48,  2.04s/it]

train loss: 1.8245793581008911


  0%|          | 61/16165 [03:28<8:52:20,  1.98s/it]

train loss: 1.7319635152816772


  0%|          | 62/16165 [03:30<8:41:08,  1.94s/it]

train loss: 1.5398319959640503


  0%|          | 63/16165 [03:32<8:33:28,  1.91s/it]

train loss: 1.9368301630020142


  0%|          | 64/16165 [03:33<8:24:57,  1.88s/it]

train loss: 1.9492976665496826


  0%|          | 65/16165 [03:35<8:19:22,  1.86s/it]

train loss: 1.6573718786239624


  0%|          | 66/16165 [03:37<8:13:10,  1.84s/it]

train loss: 1.8093254566192627


  0%|          | 67/16165 [03:39<8:09:38,  1.82s/it]

train loss: 1.6470898389816284


  0%|          | 68/16165 [03:41<8:08:00,  1.82s/it]

train loss: 1.6058018207550049


  0%|          | 69/16165 [03:42<8:06:34,  1.81s/it]

train loss: 1.704673171043396


  0%|          | 70/16165 [03:44<8:06:16,  1.81s/it]

train loss: 2.0907840728759766


  0%|          | 71/16165 [03:46<8:04:45,  1.81s/it]

train loss: 1.7495744228363037


  0%|          | 72/16165 [03:48<8:04:47,  1.81s/it]

train loss: 1.6342225074768066


  0%|          | 73/16165 [03:50<8:03:55,  1.80s/it]

train loss: 1.7480543851852417


  0%|          | 74/16165 [03:51<8:02:11,  1.80s/it]

train loss: 1.6019563674926758


  0%|          | 75/16165 [03:53<8:03:52,  1.80s/it]

train loss: 1.5597256422042847


  0%|          | 76/16165 [03:55<8:03:59,  1.80s/it]

train loss: 1.6267222166061401


  0%|          | 77/16165 [04:07<21:04:10,  4.71s/it]

train loss: 1.60991370677948


  0%|          | 78/16165 [04:18<30:10:20,  6.75s/it]

train loss: 1.4873415231704712


  0%|          | 79/16165 [04:30<36:29:23,  8.17s/it]

train loss: 1.671089768409729


  0%|          | 80/16165 [04:41<40:56:06,  9.16s/it]

train loss: 1.7359634637832642


  1%|          | 81/16165 [04:53<44:06:52,  9.87s/it]

train loss: 1.5214803218841553


  1%|          | 82/16165 [05:04<46:18:26, 10.37s/it]

train loss: 1.4153013229370117


  1%|          | 83/16165 [05:16<47:50:41, 10.71s/it]

train loss: 1.906777262687683


  1%|          | 84/16165 [05:27<48:54:42, 10.95s/it]

train loss: 1.4221693277359009


  1%|          | 85/16165 [05:39<49:41:28, 11.12s/it]

train loss: 1.5409469604492188


  1%|          | 86/16165 [05:50<50:13:29, 11.25s/it]

train loss: 1.4928981065750122


  1%|          | 87/16165 [06:02<50:34:46, 11.33s/it]

train loss: 1.300754427909851


  1%|          | 88/16165 [06:13<50:50:07, 11.38s/it]

train loss: 1.8629893064498901


  1%|          | 89/16165 [06:25<50:58:41, 11.42s/it]

train loss: 1.6547787189483643


  1%|          | 90/16165 [06:36<51:03:48, 11.44s/it]

train loss: 1.845826506614685


  1%|          | 91/16165 [06:48<51:07:45, 11.45s/it]

train loss: 1.2426567077636719


  1%|          | 92/16165 [06:59<51:12:27, 11.47s/it]

train loss: 1.472185730934143


  1%|          | 93/16165 [07:11<51:15:18, 11.48s/it]

train loss: 1.507523536682129


  1%|          | 94/16165 [07:22<51:12:26, 11.47s/it]

train loss: 1.4616868495941162


  1%|          | 95/16165 [07:34<51:15:15, 11.48s/it]

train loss: 1.452836275100708


  1%|          | 96/16165 [07:45<51:17:26, 11.49s/it]

train loss: 1.4268378019332886


  1%|          | 97/16165 [07:57<51:20:05, 11.50s/it]

train loss: 1.5799334049224854


  1%|          | 98/16165 [08:08<51:19:57, 11.50s/it]

train loss: 1.4366884231567383


  1%|          | 99/16165 [08:20<51:22:18, 11.51s/it]

train loss: 1.478736400604248


  1%|          | 100/16165 [08:31<51:22:19, 11.51s/it]

train loss: 1.411198616027832


  1%|          | 101/16165 [08:43<51:20:06, 11.50s/it]

This was the generated caption: ['* * ekern _ 1 * 0 < t > * * ekern _ 1. 0 < b > * cleff4 < t > * clefg2 < t > * k [ f * k [ f * k [ f * k [ f * k [ f * k [ f * k [ f * k [ f * k [ f * k [ f * k [ f * ] < t > * k [ f * g < t > * ] < t > * ] < t > * k [ f * ] < t > * g * g * k [ f * g * g * < t > * g * k [ f * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * < t > * < t > * g * g * g * g * g * g * g * g * g * g * g * g * g * g * < t > * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * g * < t > * g * g * g * g * < t > * g * g * g * < t > * < t > * g * g * g * g * g * g * g * < t > * < t > * g * < t > * g * g * g * g * < t > * g * g * < t > * g * < t > * < t > * g * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t > * < t >

  1%|          | 102/16165 [08:54<51:20:09, 11.51s/it]

train loss: 1.2774980068206787


  1%|          | 103/16165 [09:06<51:18:16, 11.50s/it]

train loss: 1.3346627950668335


  1%|          | 104/16165 [09:17<51:20:42, 11.51s/it]

train loss: 1.282429814338684


  1%|          | 105/16165 [09:29<51:20:26, 11.51s/it]

train loss: 1.4645605087280273


  1%|          | 106/16165 [09:40<51:18:28, 11.50s/it]

train loss: 1.4098286628723145


  1%|          | 107/16165 [09:52<51:15:27, 11.49s/it]

train loss: 1.346493124961853


  1%|          | 108/16165 [10:03<51:15:07, 11.49s/it]

train loss: 1.2513110637664795


  1%|          | 109/16165 [10:15<51:14:40, 11.49s/it]

train loss: 1.4049180746078491


  1%|          | 110/16165 [10:26<51:14:12, 11.49s/it]

train loss: 1.3986362218856812


  1%|          | 111/16165 [10:38<51:14:02, 11.49s/it]

train loss: 1.2847589254379272


  1%|          | 112/16165 [10:49<51:12:42, 11.48s/it]

train loss: 1.58637535572052


  1%|          | 113/16165 [11:01<51:11:41, 11.48s/it]

train loss: 1.4530786275863647


  1%|          | 114/16165 [11:12<51:10:06, 11.48s/it]

train loss: 1.1805167198181152


  1%|          | 115/16165 [11:23<51:08:51, 11.47s/it]

train loss: 1.7371546030044556


  1%|          | 116/16165 [11:35<51:08:17, 11.47s/it]

train loss: 1.510024070739746


  1%|          | 117/16165 [11:46<51:07:58, 11.47s/it]

train loss: 1.336954116821289


  1%|          | 118/16165 [11:58<51:07:18, 11.47s/it]

train loss: 1.2291557788848877


  1%|          | 119/16165 [12:09<51:08:19, 11.47s/it]

train loss: 1.2627068758010864


  1%|          | 120/16165 [12:21<51:07:57, 11.47s/it]

train loss: 1.352950096130371


  1%|          | 121/16165 [12:32<51:08:46, 11.48s/it]

train loss: 1.236096978187561


  1%|          | 122/16165 [12:44<51:08:17, 11.48s/it]

train loss: 1.329064965248108


  1%|          | 123/16165 [12:55<51:07:31, 11.47s/it]

train loss: 1.345861554145813


  1%|          | 124/16165 [13:07<51:06:47, 11.47s/it]

train loss: 1.414306879043579


  1%|          | 125/16165 [13:18<51:07:29, 11.47s/it]

train loss: 1.2459750175476074


  1%|          | 126/16165 [13:30<51:08:07, 11.48s/it]

train loss: 1.553471326828003


  1%|          | 127/16165 [13:41<51:12:11, 11.49s/it]

train loss: 1.3085546493530273


  1%|          | 128/16165 [13:53<51:14:55, 11.50s/it]

train loss: 1.3389090299606323


  1%|          | 129/16165 [14:04<51:13:42, 11.50s/it]

train loss: 1.2377700805664062


  1%|          | 130/16165 [14:16<51:13:58, 11.50s/it]

train loss: 1.414342999458313


  1%|          | 131/16165 [14:27<51:14:27, 11.50s/it]

train loss: 1.4466837644577026


  1%|          | 132/16165 [14:39<51:16:10, 11.51s/it]

train loss: 1.302763819694519


  1%|          | 133/16165 [14:50<51:14:14, 11.51s/it]

train loss: 1.2456114292144775


  1%|          | 134/16165 [15:02<51:16:07, 11.51s/it]

train loss: 1.2949347496032715


  1%|          | 135/16165 [15:13<51:14:03, 11.51s/it]

train loss: 1.2936718463897705


  1%|          | 136/16165 [15:25<51:12:49, 11.50s/it]

train loss: 1.1797266006469727


  1%|          | 137/16165 [15:36<51:13:05, 11.50s/it]

train loss: 1.2363743782043457


  1%|          | 138/16165 [15:48<51:13:11, 11.51s/it]

train loss: 1.3926244974136353


  1%|          | 139/16165 [15:59<51:14:24, 11.51s/it]

train loss: 1.2701929807662964


  1%|          | 140/16165 [16:11<51:11:17, 11.50s/it]

train loss: 1.177018165588379


  1%|          | 141/16165 [16:22<51:13:59, 11.51s/it]

train loss: 1.2385213375091553


  1%|          | 142/16165 [16:34<51:13:53, 11.51s/it]

train loss: 1.2808016538619995


  1%|          | 143/16165 [16:45<51:13:02, 11.51s/it]

train loss: 1.191558599472046


  1%|          | 144/16165 [16:57<51:14:14, 11.51s/it]

train loss: 1.22603440284729


  1%|          | 145/16165 [17:08<51:13:19, 11.51s/it]

train loss: 1.1429390907287598


  1%|          | 146/16165 [17:20<51:15:19, 11.52s/it]

train loss: 1.299393653869629


  1%|          | 147/16165 [17:31<51:11:14, 11.50s/it]

train loss: 1.2983863353729248


  1%|          | 148/16165 [17:43<51:08:01, 11.49s/it]

train loss: 1.1044694185256958


  1%|          | 149/16165 [17:54<51:07:09, 11.49s/it]

train loss: 1.130153775215149


  1%|          | 150/16165 [18:06<51:20:51, 11.54s/it]

train loss: 1.1724272966384888


  1%|          | 151/16165 [18:18<51:21:27, 11.55s/it]

train loss: 1.190872311592102


  1%|          | 152/16165 [18:29<51:17:17, 11.53s/it]

train loss: 1.0531201362609863


  1%|          | 153/16165 [18:41<51:13:59, 11.52s/it]

train loss: 1.1133413314819336


  1%|          | 154/16165 [18:52<51:11:26, 11.51s/it]

train loss: 1.235416293144226


  1%|          | 155/16165 [19:04<51:10:57, 11.51s/it]

train loss: 1.0119330883026123


  1%|          | 156/16165 [19:15<51:08:52, 11.50s/it]

train loss: 1.2120178937911987


  1%|          | 157/16165 [19:27<51:07:36, 11.50s/it]

train loss: 1.0886059999465942


  1%|          | 158/16165 [19:38<51:06:58, 11.50s/it]

train loss: 1.0553743839263916


  1%|          | 159/16165 [19:50<51:06:08, 11.49s/it]

train loss: 1.204396367073059


  1%|          | 160/16165 [20:01<51:05:15, 11.49s/it]

train loss: 1.0592732429504395


  1%|          | 161/16165 [20:13<51:09:48, 11.51s/it]

train loss: 1.1422297954559326


  1%|          | 162/16165 [20:24<51:15:03, 11.53s/it]

train loss: 1.1898399591445923


  1%|          | 163/16165 [20:36<51:17:09, 11.54s/it]

train loss: 1.0379115343093872


  1%|          | 164/16165 [20:47<51:16:21, 11.54s/it]

train loss: 1.0366129875183105


  1%|          | 165/16165 [20:59<51:24:07, 11.57s/it]

train loss: 1.2017405033111572


  1%|          | 166/16165 [21:10<51:26:00, 11.57s/it]

train loss: 1.207945466041565


  1%|          | 167/16165 [21:22<51:24:11, 11.57s/it]

train loss: 1.0332285165786743


  1%|          | 168/16165 [21:34<51:18:38, 11.55s/it]

train loss: 1.1726514101028442


  1%|          | 169/16165 [21:45<51:17:04, 11.54s/it]

train loss: 1.1108319759368896


  1%|          | 170/16165 [21:57<51:19:55, 11.55s/it]

train loss: 1.1846256256103516


  1%|          | 171/16165 [22:08<51:21:15, 11.56s/it]

train loss: 1.101175308227539


  1%|          | 172/16165 [22:20<51:34:10, 11.61s/it]

train loss: 1.1014795303344727


  1%|          | 173/16165 [22:32<51:34:49, 11.61s/it]

train loss: 1.024963617324829


  1%|          | 174/16165 [22:43<51:33:15, 11.61s/it]

train loss: 1.176780343055725


  1%|          | 175/16165 [22:55<52:14:55, 11.76s/it]

train loss: 1.1238194704055786


  1%|          | 176/16165 [23:07<52:31:04, 11.82s/it]

train loss: 1.1101021766662598


  1%|          | 177/16165 [23:19<52:19:12, 11.78s/it]

train loss: 1.0631252527236938


  1%|          | 178/16165 [23:30<52:02:41, 11.72s/it]

train loss: 1.2398301362991333


  1%|          | 179/16165 [23:42<51:50:56, 11.68s/it]

train loss: 1.329732894897461


  1%|          | 180/16165 [23:54<51:36:44, 11.62s/it]

train loss: 1.0501765012741089


  1%|          | 181/16165 [24:05<51:24:58, 11.58s/it]

train loss: 1.0510218143463135


  1%|          | 182/16165 [24:17<51:17:32, 11.55s/it]

train loss: 1.1235727071762085


  1%|          | 183/16165 [24:28<51:23:45, 11.58s/it]

train loss: 0.8601868748664856


  1%|          | 184/16165 [24:40<51:21:31, 11.57s/it]

train loss: 1.1207340955734253


  1%|          | 185/16165 [24:52<51:45:02, 11.66s/it]

train loss: 1.213465690612793


  1%|          | 186/16165 [25:03<51:58:41, 11.71s/it]

train loss: 0.9733908772468567


  1%|          | 187/16165 [25:15<51:50:50, 11.68s/it]

train loss: 1.1419981718063354


  1%|          | 188/16165 [25:27<51:38:15, 11.64s/it]

train loss: 1.2050844430923462


  1%|          | 189/16165 [25:38<51:46:24, 11.67s/it]

train loss: 0.9869406223297119


  1%|          | 190/16165 [25:50<51:36:47, 11.63s/it]

train loss: 1.053209662437439


  1%|          | 191/16165 [26:01<51:30:00, 11.61s/it]

train loss: 1.1133745908737183


  1%|          | 192/16165 [26:13<51:23:02, 11.58s/it]

train loss: 1.00644052028656


  1%|          | 193/16165 [26:24<51:17:28, 11.56s/it]

train loss: 0.9400385618209839


  1%|          | 194/16165 [26:36<51:12:40, 11.54s/it]

train loss: 1.0672733783721924


  1%|          | 195/16165 [26:47<51:08:24, 11.53s/it]

train loss: 1.067376971244812


  1%|          | 196/16165 [26:59<51:06:29, 11.52s/it]

train loss: 0.9564799666404724


  1%|          | 197/16165 [27:10<51:01:37, 11.50s/it]

train loss: 1.2743339538574219


  1%|          | 198/16165 [27:22<51:00:21, 11.50s/it]

train loss: 1.1365363597869873


  1%|          | 199/16165 [27:34<51:12:17, 11.55s/it]

train loss: 0.9430329203605652


  1%|          | 200/16165 [27:45<51:07:34, 11.53s/it]

train loss: 1.117624282836914


  1%|          | 201/16165 [27:56<51:02:10, 11.51s/it]

This was the generated caption: ['* * ekern _ 1. 0 < t > * * ekern _ 1. 0 < b > * cleff4 < t > * clefg2 < b > * k [ b - e - a - ] < t > * k [ b - ] < t > * m3 / 4 < t > * m3 / 4 < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t > = - < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >. < t >.', '* * ekern _ 1. 0 < t >

  1%|          | 202/16165 [28:08<50:58:13, 11.49s/it]

train loss: 1.0041110515594482


  1%|▏         | 203/16165 [28:19<50:57:15, 11.49s/it]

train loss: 1.0244061946868896


  1%|▏         | 204/16165 [28:31<50:55:31, 11.49s/it]

train loss: 1.0567632913589478


  1%|▏         | 205/16165 [28:42<50:54:05, 11.48s/it]

train loss: 1.0217633247375488


  1%|▏         | 206/16165 [28:54<50:50:50, 11.47s/it]

train loss: 1.1500898599624634


  1%|▏         | 207/16165 [29:05<50:50:05, 11.47s/it]

train loss: 1.1129792928695679


  1%|▏         | 208/16165 [29:17<50:48:09, 11.46s/it]

train loss: 1.1309643983840942


  1%|▏         | 209/16165 [29:28<50:48:18, 11.46s/it]

train loss: 0.9383618831634521


  1%|▏         | 210/16165 [29:40<50:47:01, 11.46s/it]

train loss: 0.9397245049476624


  1%|▏         | 211/16165 [29:51<50:46:05, 11.46s/it]

train loss: 1.1493059396743774


  1%|▏         | 212/16165 [30:03<50:46:36, 11.46s/it]

train loss: 0.9849404096603394


  1%|▏         | 213/16165 [30:14<50:46:54, 11.46s/it]

train loss: 1.0851322412490845


  1%|▏         | 214/16165 [30:25<50:47:07, 11.46s/it]

train loss: 1.0483537912368774


  1%|▏         | 215/16165 [30:37<50:47:08, 11.46s/it]

train loss: 0.984506368637085


  1%|▏         | 216/16165 [30:48<50:47:42, 11.47s/it]

train loss: 0.8961498141288757


  1%|▏         | 217/16165 [31:00<50:49:23, 11.47s/it]

train loss: 1.100547194480896


  1%|▏         | 218/16165 [31:11<50:48:43, 11.47s/it]

train loss: 0.8555433750152588


  1%|▏         | 219/16165 [31:23<50:52:25, 11.49s/it]

train loss: 1.220692753791809


  1%|▏         | 220/16165 [31:34<50:52:21, 11.49s/it]

train loss: 1.0985935926437378


  1%|▏         | 221/16165 [31:46<50:52:22, 11.49s/it]

train loss: 1.017150640487671


  1%|▏         | 222/16165 [31:57<50:52:35, 11.49s/it]

train loss: 1.065344214439392


  1%|▏         | 223/16165 [32:09<50:53:47, 11.49s/it]

train loss: 0.920652449131012


  1%|▏         | 224/16165 [32:20<50:56:00, 11.50s/it]

train loss: 1.1316008567810059


  1%|▏         | 225/16165 [32:32<50:57:36, 11.51s/it]

train loss: 0.9287078976631165


  1%|▏         | 226/16165 [32:43<50:55:38, 11.50s/it]

train loss: 1.0159804821014404


  1%|▏         | 227/16165 [32:55<50:54:09, 11.50s/it]

train loss: 0.9845919013023376


  1%|▏         | 228/16165 [33:06<50:54:55, 11.50s/it]

train loss: 0.9749624729156494


  1%|▏         | 229/16165 [33:18<50:53:24, 11.50s/it]

train loss: 1.014125943183899


  1%|▏         | 230/16165 [33:29<50:52:31, 11.49s/it]

train loss: 0.9126140475273132


  1%|▏         | 231/16165 [33:41<50:51:14, 11.49s/it]

train loss: 1.0262681245803833


  1%|▏         | 232/16165 [33:52<50:52:11, 11.49s/it]

train loss: 0.9497554302215576


  1%|▏         | 233/16165 [34:04<50:51:45, 11.49s/it]

train loss: 1.0313963890075684


  1%|▏         | 234/16165 [34:15<50:51:33, 11.49s/it]

train loss: 1.0247355699539185


  1%|▏         | 235/16165 [34:27<50:51:11, 11.49s/it]

train loss: 1.0228140354156494


  1%|▏         | 236/16165 [34:38<50:50:29, 11.49s/it]

train loss: 0.9720300436019897


  1%|▏         | 237/16165 [34:50<50:51:12, 11.49s/it]

train loss: 0.996263325214386


  1%|▏         | 238/16165 [35:01<50:51:42, 11.50s/it]

train loss: 1.0791693925857544


  1%|▏         | 239/16165 [35:13<50:52:54, 11.50s/it]

train loss: 1.0513793230056763


  1%|▏         | 240/16165 [35:24<50:53:31, 11.50s/it]

train loss: 0.9164827466011047


  1%|▏         | 241/16165 [35:36<50:51:39, 11.50s/it]

train loss: 1.1042430400848389


  1%|▏         | 242/16165 [35:47<50:50:28, 11.49s/it]

train loss: 1.0477560758590698


  2%|▏         | 243/16165 [35:59<50:48:01, 11.49s/it]

train loss: 0.898464024066925


  2%|▏         | 244/16165 [36:10<50:47:42, 11.49s/it]

train loss: 1.0106923580169678


  2%|▏         | 245/16165 [36:22<50:46:48, 11.48s/it]

train loss: 1.0158859491348267


  2%|▏         | 246/16165 [36:33<50:46:52, 11.48s/it]

train loss: 0.875169038772583


  2%|▏         | 247/16165 [36:45<50:46:59, 11.49s/it]

train loss: 0.8453649878501892


  2%|▏         | 248/16165 [36:56<50:46:46, 11.48s/it]

train loss: 1.047399640083313


  2%|▏         | 249/16165 [37:08<50:46:46, 11.49s/it]

train loss: 0.9300259351730347


  2%|▏         | 250/16165 [37:19<50:46:47, 11.49s/it]

train loss: 0.9363964796066284


  2%|▏         | 251/16165 [37:31<50:45:30, 11.48s/it]

train loss: 1.0323035717010498


  2%|▏         | 252/16165 [37:42<50:44:24, 11.48s/it]

train loss: 0.9609911441802979


  2%|▏         | 253/16165 [37:54<50:42:53, 11.47s/it]

train loss: 0.9790931940078735


  2%|▏         | 254/16165 [38:05<50:43:45, 11.48s/it]

train loss: 0.9920989871025085


  2%|▏         | 255/16165 [38:17<50:42:31, 11.47s/it]

train loss: 1.0766210556030273


  2%|▏         | 256/16165 [38:28<50:42:55, 11.48s/it]

train loss: 0.8157442212104797


  2%|▏         | 257/16165 [38:39<50:43:50, 11.48s/it]

train loss: 0.8724854588508606


  2%|▏         | 258/16165 [38:51<50:42:48, 11.48s/it]

train loss: 0.8949875235557556


  2%|▏         | 259/16165 [39:02<50:41:38, 11.47s/it]

train loss: 1.1909263134002686


  2%|▏         | 260/16165 [39:14<50:42:20, 11.48s/it]

train loss: 1.121341347694397


  2%|▏         | 261/16165 [39:25<50:42:40, 11.48s/it]

train loss: 1.0033013820648193


  2%|▏         | 262/16165 [39:37<50:43:00, 11.48s/it]

train loss: 1.0109574794769287


  2%|▏         | 263/16165 [39:48<50:42:11, 11.48s/it]

train loss: 0.9631034731864929


  2%|▏         | 264/16165 [40:00<50:44:25, 11.49s/it]

train loss: 1.106876015663147


  2%|▏         | 265/16165 [40:11<50:47:17, 11.50s/it]

train loss: 1.0849391222000122


  2%|▏         | 266/16165 [40:23<50:52:25, 11.52s/it]

train loss: 0.8925009965896606


  2%|▏         | 267/16165 [40:35<50:54:09, 11.53s/it]

train loss: 1.2311139106750488


  2%|▏         | 268/16165 [40:46<50:55:35, 11.53s/it]

train loss: 1.0158977508544922


  2%|▏         | 269/16165 [40:58<50:55:30, 11.53s/it]

train loss: 0.8514686226844788


  2%|▏         | 270/16165 [41:09<50:54:07, 11.53s/it]

train loss: 1.1569963693618774


  2%|▏         | 271/16165 [41:21<50:52:14, 11.52s/it]

train loss: 0.9892265796661377


  2%|▏         | 272/16165 [41:32<50:50:35, 11.52s/it]

train loss: 0.8874739408493042


  2%|▏         | 273/16165 [41:44<50:49:20, 11.51s/it]

train loss: 0.8583935499191284


  2%|▏         | 274/16165 [41:55<50:47:30, 11.51s/it]

train loss: 1.0738155841827393


  2%|▏         | 275/16165 [42:07<50:49:04, 11.51s/it]

train loss: 1.1006832122802734


  2%|▏         | 276/16165 [42:18<50:50:11, 11.52s/it]

train loss: 0.9945022463798523


  2%|▏         | 277/16165 [42:30<50:50:32, 11.52s/it]

train loss: 1.064387559890747


  2%|▏         | 278/16165 [42:41<50:50:43, 11.52s/it]

train loss: 0.6689918041229248


  2%|▏         | 279/16165 [42:53<50:50:30, 11.52s/it]

train loss: 1.0510791540145874


In [None]:
#save the model in case the training was cancelled/"keyboard-interrupted" early bc of session limits 
#takes a deep copy of the model's *current* state dict and writes it to Google Drive with torch.save
#(weights only -- the state dict, not the whole model object)
#NOTE(review): this rebinds best_model_wts to the current weights; if the training loop also keeps its
#best checkpoint under this name, running this cell replaces it -- confirm that is intended
best_model_wts = copy.deepcopy(model.state_dict())
torch.save(best_model_wts, 'drive/MyDrive/fine_tuned_blip_early_stop_WEIGHTS_ONLY') #change path as needed

In [None]:
#for loading in a saved set of model weights when needed
#map_location=device remaps the saved tensors onto this session's device, so a checkpoint written
#from a GPU session still loads when the current session has no GPU (and vice versa)
#NOTE(review): torch.load unpickles the file, so only load checkpoints that we created ourselves
model.load_state_dict(torch.load("drive/MyDrive/fine_tuned_blip_wer_val_idx_0", map_location=device)) #change path as needed

### Evaluation of Fine-Tuned BLIP

Let's now see how the fine-tuned BLIP model does with the test set. Like before, eval here will be based on Word Error Rate (WER), which 
seems to be a better metric for this use case than other popular metrics like ROUGE-N. WER is the number of word-level errors (substitutions, deletions, and insertions) divided by 
the total number of reference words. WER=0.0 indicates a perfect match; a WER of 1.0 or more means there are at least as many errors as reference words (WER is not capped at 1.0, since extra predicted words count as insertions). So the lower the score, the better.

(WER is sensitive to whether or not predicted words are exact matches to the reference labels + the order in which predicted words appear, two properties
that are very important for our use case. E.g., if a predicted sequence has all the same words as the reference sequence but in a different order -- that is 
effectively a completely different musical score. 

The original paper that produced the raw dataset we're using also uses variations of WER to judge baseline model runs on its dataset.)

In [None]:
model.eval() # set the model to eval mode if it wasn't there already

#per-sample ground truths and predictions (token ids + decoded text), collected over the whole test set
all_gt_ids, all_gt_text = [], []
all_pred_ids, all_pred_text = [], []

#running sum of the per-batch WER scores; rwer/(number of batches seen) is the mean batch-level WER
rwer = 0

with torch.no_grad(): #not computing gradients while testing
  for idx, batch in enumerate(tqdm.tqdm(test_dataloader_blip)):

    #retrieve the image and gt labels
    #(only the image is moved to the device; the labels are just decoded to text below, never fed to the model)
    pixel_values = batch.pop("pixel_values").to(device)
    labels = batch.pop("input_ids")

    #make prediction and get both the outputted ids and the text caption
    generated_ids = model.generate(pixel_values=pixel_values, max_new_tokens=500)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)

    #get the text of the labels
    #(each label sequence is cut to its first 500 ids -- presumably to mirror max_new_tokens=500 above; confirm)
    gt_caption = processor.batch_decode([i[:500] for i in labels], skip_special_tokens=True)

    wer_score = wer.compute(predictions=generated_caption, references=gt_caption)
    rwer += wer_score

    #keep the aggregate current on every step so the final value is available after the loop
    agg_wer = rwer/(idx+1)
    if idx % 100 == 0:
      print(f"Step {idx+1} -- wer: {agg_wer}")

    #store the preds and gts for eval later
    #.cpu() so the stored tensors don't keep accelerator memory occupied for the whole (very long) run,
    #and so the pickled results aren't tied to the device they were produced on
    for pred_ids, pred_cap, gt_ids, gt_cap in zip(generated_ids, generated_caption, labels, gt_caption):
      all_pred_ids.append(pred_ids.cpu())
      all_pred_text.append(pred_cap)
      all_gt_ids.append(gt_ids.cpu())
      all_gt_text.append(gt_cap)

#report the aggregate over the full test set (the in-loop print only fires on every 100th step)
if all_pred_text:
  print(f"Final ({idx+1} steps) -- wer: {agg_wer}")

In [None]:
#a full pass over the test set takes over a day on the gpu used here, so the predictions and the
#gt labels are persisted to drive; they can be loaded back later to calc more metrics or inspect outputs
#each (path, object) pair below is written to its own pickle file
for pickle_path, payload in (
    ('drive/MyDrive/all_pred_ids.pickle', all_pred_ids),
    ('drive/MyDrive/all_pred_text.pickle', all_pred_text),
    ('drive/MyDrive/all_gt_ids.pickle', all_gt_ids),
    ('drive/MyDrive/all_gt_text.pickle', all_gt_text),
):
    with open(pickle_path, 'wb') as output:
        pickle.dump(payload, output)