# Environments

In [247]:
import sys
version = sys.version # Python version
print(version)

3.9.13 (main, Aug 25 2022, 23:51:50) [MSC v.1916 64 bit (AMD64)]


In [248]:
import torch
print(torch.__version__) # PyTorch version
print(torch.backends.cudnn.version()) # cuDNN version

1.13.0
8500


In [249]:
!nvcc --version # CUDA version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_May__3_19:00:59_Pacific_Daylight_Time_2022
Cuda compilation tools, release 11.7, V11.7.64
Build cuda_11.7.r11.7/compiler.31294372_0


In [4]:
count = torch.cuda.device_count()
for i in range(count): # Available GPUs
    name = torch.cuda.get_device_name(i)
    print(i, ":", name)

0 : NVIDIA GeForce RTX 4090


# Importing Dependencies

In [233]:
from torch.utils.data import Dataset, DataLoader, random_split
from torch import autograd
import torch.optim as optim

In [38]:
import statistics

In [6]:
import csv
from collections import OrderedDict, defaultdict
import numpy as np
import json
from scipy.stats import gaussian_kde
from tqdm import tqdm

In [7]:
from music21 import *
import miditoolkit
import zipfile
import io
import random
import hashlib
from multiprocessing import Pool, Lock, Manager

In [8]:
import os
path = r"C:/Users/chowd/Jupyter/content/"
os.chdir(path)
os.chdir("./muzic/musicbert/")
print(os.getcwd())

C:\Users\chowd\Jupyter\content\muzic\musicbert


In [9]:
from musicbert import *

disable_cp = False
mask_strategy = ['bar']
convert_encoding = OCTMIDI
crop_length = None


In [10]:
def playMidi(filename):
    mf = midi.MidiFile()
    mf.open(filename)
    mf.read()
    mf.close()
    s = midi.translate.midiFileToStream(mf)
    s.show('midi')

In [11]:
def parse_midi_file(sample_midi_path: str):
    midi_obj = miditoolkit.midi.parser.MidiFile(sample_midi_path)
    midi_name = sample_midi_path.split('/')[-1].split('.')[0]
    return midi_obj, midi_name

# Preprocessing Data

In [12]:
bar_max = 256 #@param {type:"integer"}

In [13]:
pos_resolution = 16  # per beat (quarter note)

velocity_quant = 4
tempo_quant = 12  # 2 ** (1 / 12)
min_tempo = 16
max_tempo = 256
duration_max = 8  # 2 ** 8 * beat
max_ts_denominator = 6  # x/1 x/2 x/4 ... x/64
max_notes_per_bar = 2  # 1/64 ... 128/64
beat_note_factor = 4  # In MIDI format a note is always 4 beats
deduplicate = True
filter_symbolic = False
filter_symbolic_ppl = 16
trunc_pos = 2 ** 16  # approx 30 minutes (1024 measures)
sample_len_max = 1000  # window length max
sample_overlap_rate = 4
ts_filter = False
pool_num = 24
max_inst = 127
max_pitch = 127
max_velocity = 127

data_zip = None
output_file = None

In [14]:
ts_dict = dict()
ts_list = list()
for i in range(0, max_ts_denominator + 1):  # 1 ~ 64
    for j in range(1, ((2 ** i) * max_notes_per_bar) + 1):
        ts_dict[(j, 2 ** i)] = len(ts_dict)
        ts_list.append((j, 2 ** i))
dur_enc = list()
dur_dec = list()
for i in range(duration_max):
    for j in range(pos_resolution):
        dur_dec.append(len(dur_enc))
        for k in range(2 ** i):
            dur_enc.append(len(dur_dec) - 1)

In [15]:
def t2e(x):
    assert x in ts_dict, 'unsupported time signature: ' + str(x)
    return ts_dict[x]


def e2t(x):
    return ts_list[x]


def d2e(x):
    return dur_enc[x] if x < len(dur_enc) else dur_enc[-1]


def e2d(x):
    return dur_dec[x] if x < len(dur_dec) else dur_dec[-1]


def v2e(x):
    return x // velocity_quant


def e2v(x):
    return (x * velocity_quant) + (velocity_quant // 2)


def b2e(x):
    x = max(x, min_tempo)
    x = min(x, max_tempo)
    x = x / min_tempo
    e = round(math.log2(x) * tempo_quant)
    return e


def e2b(x):
    return 2 ** (x / tempo_quant) * min_tempo

In [16]:
def time_signature_reduce(numerator, denominator):
    # reduction (when denominator is too large)
    while denominator > 2 ** max_ts_denominator and denominator % 2 == 0 and numerator % 2 == 0:
        denominator //= 2
        numerator //= 2
    # decomposition (when length of a bar exceed max_notes_per_bar)
    while numerator > max_notes_per_bar * denominator:
        for i in range(2, numerator + 1):
            if numerator % i == 0:
                numerator //= i
                break
    return numerator, denominator


def writer(output_str_list, output_file):
    # note: parameter "file_name" is reserved for patching
    with open(output_file, 'a') as f:
        for output_str in output_str_list:
            f.write(output_str + '\n')

In [17]:
def get_hash(encoding):
    # add i[4] and i[5] for stricter match
    midi_tuple = tuple((i[2], i[3]) for i in encoding)
    midi_hash = hashlib.md5(str(midi_tuple).encode('ascii')).hexdigest()
    return midi_hash

In [18]:
def MIDI_to_encoding(midi_obj):
    def time_to_pos(t):
        return round(t * pos_resolution / midi_obj.ticks_per_beat)
    notes_start_pos = [time_to_pos(j.start)
                       for i in midi_obj.instruments for j in i.notes]
    if len(notes_start_pos) == 0:
        return list()
    max_pos = min(max(notes_start_pos) + 1, trunc_pos)
    pos_to_info = [[None for _ in range(4)] for _ in range(
        max_pos)]  # (Measure, TimeSig, Pos, Tempo)
    tsc = midi_obj.time_signature_changes
    tpc = midi_obj.tempo_changes
    for i in range(len(tsc)):
        for j in range(time_to_pos(tsc[i].time), time_to_pos(tsc[i + 1].time) if i < len(tsc) - 1 else max_pos):
            if j < len(pos_to_info):
                pos_to_info[j][1] = t2e(time_signature_reduce(
                    tsc[i].numerator, tsc[i].denominator))
    for i in range(len(tpc)):
        for j in range(time_to_pos(tpc[i].time), time_to_pos(tpc[i + 1].time) if i < len(tpc) - 1 else max_pos):
            if j < len(pos_to_info):
                pos_to_info[j][3] = b2e(tpc[i].tempo)
    for j in range(len(pos_to_info)):
        if pos_to_info[j][1] is None:
            # MIDI default time signature
            pos_to_info[j][1] = t2e(time_signature_reduce(4, 4))
        if pos_to_info[j][3] is None:
            pos_to_info[j][3] = b2e(120.0)  # MIDI default tempo (BPM)
    cnt = 0
    bar = 0
    measure_length = None
    for j in range(len(pos_to_info)):
        ts = e2t(pos_to_info[j][1])
        if cnt == 0:
            measure_length = ts[0] * beat_note_factor * pos_resolution // ts[1]
        pos_to_info[j][0] = bar
        pos_to_info[j][2] = cnt
        cnt += 1
        if cnt >= measure_length:
            assert cnt == measure_length, 'invalid time signature change: pos = {}'.format(
                j)
            cnt -= measure_length
            bar += 1
    encoding = []
    start_distribution = [0] * pos_resolution
    for inst in midi_obj.instruments:
        for note in inst.notes:
            if time_to_pos(note.start) >= trunc_pos:
                continue
            start_distribution[time_to_pos(note.start) % pos_resolution] += 1
            info = pos_to_info[time_to_pos(note.start)]
            encoding.append((info[0], info[2], max_inst + 1 if inst.is_drum else inst.program, note.pitch + max_pitch +
                             1 if inst.is_drum else note.pitch, d2e(time_to_pos(note.end) - time_to_pos(note.start)), v2e(note.velocity), info[1], info[3]))
    if len(encoding) == 0:
        return list()
    tot = sum(start_distribution)
    start_ppl = 2 ** sum((0 if x == 0 else -(x / tot) *
                          math.log2((x / tot)) for x in start_distribution))
    # filter unaligned music
    if filter_symbolic:
        assert start_ppl <= filter_symbolic_ppl, 'filtered out by the symbolic filter: ppl = {:.2f}'.format(
            start_ppl)
    encoding.sort()
    return encoding

In [19]:
def str_to_encoding(s):
    encoding = [int(i[3: -1]) for i in s.split() if 's' not in i]
    tokens_per_note = 8
    assert len(encoding) % tokens_per_note == 0
    encoding = [tuple(encoding[i + j] for j in range(tokens_per_note))
                for i in range(0, len(encoding), tokens_per_note)]
    return encoding

In [20]:
def encoding_to_str(e, bar_max = bar_max):
    bar_index_offset = 0
    p = 0
    tokens_per_note = 8
    return ' '.join((['<s>'] * tokens_per_note)
                    + ['<{}-{}>'.format(j, k if j > 0 else k + bar_index_offset) for i in e[p: p +
                                                                                            sample_len_max] if i[0] + bar_index_offset < bar_max for j, k in enumerate(i)]
                    + (['</s>'] * (tokens_per_note
                                   - 1)))   # 8 - 1 for append_eos functionality of binarizer in fairseq

In [21]:
lock_file = Lock()
lock_write = Lock()
lock_set = Lock()
manager = Manager()
midi_dict = manager.dict()

e = None
midi_file = None
file_name = None

def encode_midi(midi_file):
    midi_obj = miditoolkit.midi.parser.MidiFile(file=midi_file)
    
    assert all(0 <= j.start < 2 ** 31 and 0 <= j.end < 2 **
               31 for i in midi_obj.instruments for j in i.notes), 'bad note time'
    assert all(0 < j.numerator < 2 ** 31 and 0 < j.denominator < 2 **
               31 for j in midi_obj.time_signature_changes), 'bad time signature value'
    assert 0 < midi_obj.ticks_per_beat < 2 ** 31, 'bad ticks per beat'
    
    midi_notes_count = sum(len(inst.notes) for inst in midi_obj.instruments)
    if midi_notes_count == 0:
        print('ERROR(BLANK): ' + file_name + '\n', end='')
        return None
    
    e = MIDI_to_encoding(midi_obj)
    length = len(e)
    
    if length == 0:
        print('ERROR(BLANK): ' + file_name + '\n', end='')
        return None
    if ts_filter:
        allowed_ts = t2e(time_signature_reduce(4, 4))
        if not all(i[6] == allowed_ts for i in e):
            print('ERROR(TSFILT): ' + file_name + '\n', end='')
            return None

    dup_file_name = ''
    midi_hash = '0' * 32

    midi_hash = get_hash(e)
    lock_set.acquire()
    if midi_hash in midi_dict:
        dup_file_name = midi_dict[midi_hash]
        duplicated = True
    else:
        midi_dict[midi_hash] = file_name
        duplicated = False
    lock_set.release()

    output_str_list = []
    sample_step = max(round(sample_len_max / sample_overlap_rate), 1)
    for p in range(0 - random.randint(0, sample_len_max - 1), length, sample_step):
        L = max(p, 0)
        R = min(p + sample_len_max, length) - 1
        bar_index_list = [e[i][0] for i in range(L, R + 1) if e[i][0] is not None]
        bar_index_min = 0
        bar_index_max = 0
        if len(bar_index_list) > 0:
            bar_index_min = min(bar_index_list)
            bar_index_max = max(bar_index_list)
        offset_lower_bound = -bar_index_min
        offset_upper_bound = bar_max - 1 - bar_index_max
        # to make bar index distribute in [0, bar_max)
        bar_index_offset = random.randint(
            offset_lower_bound, offset_upper_bound) if offset_lower_bound <= offset_upper_bound else offset_lower_bound
        e_segment = []
        for i in e[L: R + 1]:
            if i[0] is None or i[0] + bar_index_offset < bar_max:
                e_segment.append(i)
            else:
                break
        tokens_per_note = 8
        output_words = (['<s>'] * tokens_per_note) \
            + [('<{}-{}>'.format(j, k if j > 0 else k + bar_index_offset) 
                if k is not None else '<unk>') for i in e_segment for j, k in enumerate(i)] \
            + (['</s>'] * (tokens_per_note - 1)
               )  # tokens_per_note - 1 for append_eos functionality of binarizer in fairseq
        output_str_list.append(' '.join(output_words))

    # no empty
    if not all(len(i.split()) > tokens_per_note * 2 - 1 for i in output_str_list):
        print('ERROR(ENCODE): ' + file_name + ' ' + str(e) + '\n', end='')
        return None

    return output_str_list
    


def F(data_path, output_path):
    midi_file_list = dict()
    
    lock_file.acquire()
    
    data_zip = zipfile.ZipFile(data_path, 'r')
    file_name_list = data_zip.namelist()
    
    lock_file.release()
    
    lock_write.acquire()
    
    for file_name in file_name_list:
        with data_zip.open(file_name) as f:
            if file_name[-3:] != "mid":
                continue
            # this may fail due to unknown bug
            midi_file = io.BytesIO(f.read())
        
        encoded_midi = encode_midi(midi_file)
        output_file = output_path + "/" + file_name.split("/")[-1].split(".")[0] + ".txt"
        writer(encoded_midi, output_file)
    
    lock_write.release()
   
    return True

In [21]:
os.chdir(path)
data_path = r'./segmented_midi.zip'
output_path = r'./encoded_tokens'
result = F(data_path, output_path)
print(result)

True


In [22]:
LABEL_LIST = ["Stable beat", "Mechanical Tempo", "Intensional", "Regular beat change", "Long", "Cushioned", "Saturated (wet)", "Clean", "Subtle change", "Even", "Rich", "Bright", 
"Pure", "Soft", "Sophisticated(mellow)", "balanced", "Large range of dynamic", "Fast paced", "Flowing", "Swing(Flexible)", "Flat", "Harmonious", "Optimistic(pleasant)", "HIgh Energy", 
"Dominant(forceful)", "Imaginative", "Ethereal", "Convincing"]
LABEL_MAP = {i: label for i, label in enumerate(LABEL_LIST)}
PIANIST_MAP = OrderedDict()

In [23]:
def estimate_maxima(data):
    """Calculate maxima using kernel density estimation"""
    if len(set(data))<=1: # all datas are equal
        return data[0]
    kde = gaussian_kde(data)
    no_samples = 50
    samples = np.linspace(min(data), max(data), no_samples)
    probs = kde.evaluate(samples)
    #maxima_index = probs.argmax()
    # in case if more than 1 argmaxs
    winner = np.argwhere(probs == np.amax(probs))
    maxima = np.average(samples[winner.flatten()])
    return maxima

In [24]:
def midi_label_map_apex(file, target):

    csvreader = csv.reader(file)
    header = []
    header = next(csvreader)

    rows = []
    for row in csvreader:
        rows.append(row)

    # sort by each segments
    music_label_map = defaultdict(list)
    for row in rows:
        user = row[0]
        file_name = row[2].split(".")[0]
        #label_row = row[3:-2]
        label_row = [row[3]] + row[7:-2] # skip 1-2 ~ 1-3
        for idx, elem in enumerate(label_row):
            if elem == "":
                label_row[idx] = 0.0
            else:
                label_row[idx] = float(elem)
        # skip 0
        if 0.0 in label_row:
            continue
        else:
            music_label_map[file_name].append(label_row)

    music_label_map_apex = dict()

    # kernel density estimation
    for key, annot_list in tqdm(music_label_map.items()):
        annot_list = np.array(annot_list).transpose()
        maxima = np.array([estimate_maxima(row)/7 for row in annot_list])
        maxima = maxima.transpose().tolist()
        music_label_map_apex[key] = maxima

    # add pianist info
    for key, annot_list in tqdm(music_label_map_apex.items()):
        if key.split("_")[-2] not in PIANIST_MAP:
            PIANIST_MAP[key.split("_")[-2]] = len(PIANIST_MAP)
    print(PIANIST_MAP)
    
    for key, annot_list in tqdm(music_label_map_apex.items()):
        music_label_map_apex[key].append(PIANIST_MAP[key.split("_")[-2]])

    json.dump(music_label_map_apex, target)

os.chdir(path)
file = open('./total.csv', encoding="utf-8")
target = open("./muzic/musicbert/midi_label_map_apex_reg_cls.json", 'w')
midi_label_map_apex(file, target)

100%|███████████████████████████████████████████████████████████████████████████████| 846/846 [00:02<00:00, 416.46it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████| 846/846 [00:00<?, ?it/s]


OrderedDict([('1', 0), ('2', 1), ('4', 2), ('5', 3), ('Score', 4), ('0', 5), ('3', 6), ('9', 7), ('10', 8), ('7', 9), ('8', 10), ('6', 11), ('12', 12)])


100%|████████████████████████████████████████████████████████████████████████████| 846/846 [00:00<00:00, 846869.02it/s]


In [25]:
file.close()
target.close()

In [59]:
def midi_label_map_apex_mode(file, target):

    csvreader = csv.reader(file)
    header = []
    header = next(csvreader)

    rows = []
    for row in csvreader:
        rows.append(row)

    # sort by each segments
    music_label_map = defaultdict(list)
    for row in rows:
        user = row[0]
        file_name = row[2].split(".")[0]
        #label_row = row[3:-2]
        label_row = [row[3]] + row[7:-2] # skip 1-2 ~ 1-3
        for idx, elem in enumerate(label_row):
            if elem == "":
                label_row[idx] = 0.0
            else:
                label_row[idx] = float(elem)
        # skip 0
        if 0.0 in label_row:
            continue
        else:
            music_label_map[file_name].append(label_row)

    music_label_map_apex = dict()

    # kernel density estimation
    for key, annot_list in tqdm(music_label_map.items()):
        annot_list = np.array(annot_list).transpose()
        maxima = np.array([statistics.mode(row) for row in annot_list])
        maxima = maxima.transpose().tolist()
        music_label_map_apex[key] = maxima

    # add pianist info
    for key, annot_list in tqdm(music_label_map_apex.items()):
        if key.split("_")[-2] not in PIANIST_MAP:
            PIANIST_MAP[key.split("_")[-2]] = len(PIANIST_MAP)
    print(PIANIST_MAP)
    
    for key, annot_list in tqdm(music_label_map_apex.items()):
        music_label_map_apex[key].append(PIANIST_MAP[key.split("_")[-2]])

    json.dump(music_label_map_apex, target)

os.chdir(path)
file = open('./total.csv', encoding="utf-8")
target = open("./muzic/musicbert/midi_label_map_apex_reg_cls_mode.json", 'w')
midi_label_map_apex_mode(file, target)

100%|█████████████████████████████████████████████████████████████████████████████| 846/846 [00:00<00:00, 16918.40it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████| 846/846 [00:00<?, ?it/s]


OrderedDict([('1', 0), ('2', 1), ('4', 2), ('5', 3), ('Score', 4), ('0', 5), ('3', 6), ('9', 7), ('10', 8), ('7', 9), ('8', 10), ('6', 11), ('12', 12)])


100%|████████████████████████████████████████████████████████████████████████████████████████| 846/846 [00:00<?, ?it/s]


In [60]:
file.close()
target.close()

# Pretrained Model

In [262]:
os.chdir(path)
os.chdir('./muzic/musicbert/')
roberta_base = MusicBERTModel.from_pretrained('.', 
  checkpoint_file = './checkpoints/checkpoint_last_musicbert_base_w_genre_head.pt',
#  user_dir='C:/Users/chowd/Jupyter/content/muzic/musicbert/musicbert'    # activate the MusicBERT plugin with this keyword
)
#,
 # data_name_or_path = '.')
print(roberta_base)

RobertaHubInterface(
  (model): MusicBERTModel(
    (encoder): MusicBERTEncoder(
      (sentence_encoder): OctupleEncoder(
        (dropout_module): FairseqDropout()
        (embed_tokens): Embedding(1237, 768, padding_idx=1)
        (embed_positions): LearnedPositionalEmbedding(8194, 768, padding_idx=1)
        (emb_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (layers): ModuleList(
          (0): TransformerSentenceEncoderLayer(
            (dropout_module): FairseqDropout()
            (activation_dropout_module): FairseqDropout()
            (self_attn): MultiheadAttention(
              (dropout_module): FairseqDropout()
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (out_proj): Linear(in_features=768, out_features=768, bias=True)
            )
            (

In [263]:
roberta_base.model.max_positions()

8192

In [264]:
roberta_base1 = roberta_base.model.encoder.sentence_encoder
roberta_base.cuda()
roberta_base.eval()
roberta_base1

OctupleEncoder(
  (dropout_module): FairseqDropout()
  (embed_tokens): Embedding(1237, 768, padding_idx=1)
  (embed_positions): LearnedPositionalEmbedding(8194, 768, padding_idx=1)
  (emb_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (layers): ModuleList(
    (0): TransformerSentenceEncoderLayer(
      (dropout_module): FairseqDropout()
      (activation_dropout_module): FairseqDropout()
      (self_attn): MultiheadAttention(
        (dropout_module): FairseqDropout()
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (fc2): Linear(in_features=3072, out_features=768, bia

In [265]:
for i, (name, param) in enumerate(roberta_base1.named_parameters()):
    param.requires_grad = False

## Generate Dataset and Dataloader

In [52]:
### Get midi file list 
data_path = r'C:/Users/chowd/Jupyter/content/'
file_name = r'segmented_midi.zip'

lock_file.acquire()
data_zip = zipfile.ZipFile(data_path + file_name, 'r')
file_name_list = data_zip.namelist()
lock_file.release()

file_name_list = [data_path + x for x in file_name_list if x[-3:] == "mid"]
print(len(file_name_list))
#print(file_name_list)

1018


In [53]:
### Get midi data
midi_data = dict()
for midi_path in file_name_list:
    name = midi_path.split("/")[-1].split(".")[0]
    midi_obj = miditoolkit.midi.parser.MidiFile(midi_path)
    tokenized_string = encoding_to_str(MIDI_to_encoding(midi_obj))
    encoded_tensor = roberta_base.task.label_dictionary.encode_line(tokenized_string)
    tensor = torch.Tensor(encoded_tensor)
    reshaped_tensor = torch.reshape(tensor, (-1, 8)).cuda()
    midi_data[name] = reshaped_tensor
len(midi_data)

1018

In [62]:
### Get labels
os.chdir(path)
os.chdir('./muzic/musicbert/')
rows = list()
data = None
with open("midi_label_map_apex_reg_cls.json", 'r') as file:
    label_data = json.load(file)
#data1 = json.dumps(data)
len(label_data)

846

In [66]:
rows = list()
data = None
with open("midi_label_map_apex_reg_cls_mode.json", 'r') as file:
    label_data_mode = json.load(file)
#data1 = json.dumps(data)
len(label_data_mode)

846

In [70]:
### Sort out data
midi_keys = set(midi_data.keys())
label_keys = set(label_data.keys())
file_list = midi_keys.intersection(label_keys)
length = len(file_list)
x = list()
y = list()
y_mode = list()
for file_name in file_list:
    x.append(torch.Tensor(midi_data[file_name]).cuda())
    y.append(torch.Tensor(label_data[file_name]).cuda())
    y_mode.append(torch.Tensor(label_data_mode[file_name]).cuda())
print(length)

799


In [68]:
### Define dataset class
class MyDataSet(Dataset):
    def __init__(self, midi_data, label_data, length):
        self.x = midi_data
        self.y = label_data
        self.len = length
    
    def __getitem__(self, index):
        return (self.x[index], self.y[index])
    
    def __len__(self):
        return self.len

In [69]:
train_data = MyDataSet(x, y, length)
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)

## Train Classification Model

In [71]:
def reshape(func, x):
    y = x[0][0]
    y = torch.permute(y, (1,0,2))
    y = torch.reshape(y, (-1, 8*768))
    y = torch.stack([
        func(x) for x in torch.unbind(y, dim=0)
    ], dim=0)
    return y

In [94]:
class Classifier(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        pretrained_model,
        input_dim,
        num_classes,
        activation_fn,
        pooler_dropout
    ):
        super().__init__()
        self.pretrained_model = pretrained_model
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.dense = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        y = self.pretrained_model.forward(x, last_state_only=True)
        y = reshape(self.pretrained_model.downsampling, y)
        y = y[0]  # take <s> token (equiv. to [CLS])
        y = self.dropout(y)
        y = self.dense(y)
        return y
    
    def freeze_pretrained_model(self):
        for i, (name, param) in enumerate(self.pretrained_model.named_parameters()):
            param.requires_grad = False
            
    def unfreeze_pretrained_model(self):
        for i, (name, param) in enumerate(self.pretrained_model.named_parameters()):
            param.requires_grad = True

In [95]:
generator = torch.Generator()
generator.manual_seed(0)

dataset = MyDataSet(x, y_mode, length)
dataset_size = len(dataset)
train_size = int(dataset_size * 0.8)
val_size = int(dataset_size * 0.1)
test_size = dataset_size - train_size - val_size
train_data, val_data, test_data = random_split(dataset=dataset,
                                               lengths=[train_size, val_size, test_size],
                                               generator=generator)
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1, shuffle=True)

In [105]:
model = Classifier(pretrained_model = roberta_base1,
                  input_dim = 768,
                  num_classes = 7*25,
                  activation_fn = nn.ReLU(),
                  pooler_dropout = 0.1
                  )
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.1)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

Classifier(
  (pretrained_model): OctupleEncoder(
    (dropout_module): FairseqDropout()
    (embed_tokens): Embedding(1237, 768, padding_idx=1)
    (embed_positions): LearnedPositionalEmbedding(8194, 768, padding_idx=1)
    (emb_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0): TransformerSentenceEncoderLayer(
        (dropout_module): FairseqDropout()
        (activation_dropout_module): FairseqDropout()
        (self_attn): MultiheadAttention(
          (dropout_module): FairseqDropout()
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=768, out_features=3072, bi

In [150]:
autograd.set_detect_anomaly(True)
epoch = 20
batch_size = 8
train_loss = 0
val_loss = 0
test_loss = 0
for epoch in range(epoch):
    loss_sum = 0
    correct = 0
    for (i, data) in enumerate(train_loader):
        inputs, labels = data
        labels = labels.type(torch.LongTensor)
        inputs, labels = inputs.to(device), labels.to(device)
        y_pred = torch.Tensor(model(inputs.squeeze(dim = 0)))
        y_pred = torch.reshape(y_pred, (-1,7))
        y_pred_mode = torch.argmax(y_pred, 1)
        label = labels[0][:-1]-1
        correct = correct + torch.sum(label == y_pred_mode).item()
        loss = criterion(y_pred, label)
        loss_sum = loss_sum + loss
        if (len(train_loader) - i) % batch_size == 1:
            loss_sum = loss_sum / len(train_loader)
            train_loss = loss_sum.item()
            train_acc = correct / (len(train_loader) * 25)
            optimizer.zero_grad()
            loss_sum.backward()
            optimizer.step()
            loss_sum = 0
    with torch.no_grad():
        correct = 0
        for (i, data) in enumerate(val_loader):
            count = count + 1
            inputs, labels = data
            labels = labels.type(torch.LongTensor)
            inputs, labels = inputs.to(device), labels.to(device)
            y_pred = torch.Tensor(model(inputs.squeeze(dim = 0)))
            y_pred = torch.reshape(y_pred, (-1,7))
            y_pred_mode = torch.argmax(y_pred, 1)
            label = labels[0][:-1]-1
            correct = correct + torch.sum(label == y_pred_mode).item()
        val_acc = correct / (len(val_loader) * 25)
    print("Epoch: {0}, Train accuracy: {1:0.5f}, Validation accuracy: {2:0.5f}".format(epoch, train_acc, val_acc))
with torch.no_grad():
        correct = 0
        for (i, data) in enumerate(test_loader):
            inputs, labels = data
            labels = labels.type(torch.LongTensor)
            inputs, labels = inputs.to(device), labels.to(device)
            y_pred = torch.Tensor(model(inputs.squeeze(dim = 0)))
            y_pred = torch.reshape(y_pred, (-1,7))
            y_pred_mode = torch.argmax(y_pred, 1)
            label = labels[0][:-1]-1
            correct = correct + torch.sum(label == y_pred_mode).item()
        test_acc = correct / (len(test_loader) * 25)
print("Test accuracy: {0:0.5f}".format(test_acc))

Epoch: 0, Train accuracy: 0.27524, Validation accuracy: 0.28911
Epoch: 1, Train accuracy: 0.28263, Validation accuracy: 0.27646
Epoch: 2, Train accuracy: 0.27674, Validation accuracy: 0.27443
Epoch: 3, Train accuracy: 0.28213, Validation accuracy: 0.28051
Epoch: 4, Train accuracy: 0.28401, Validation accuracy: 0.28506
Epoch: 5, Train accuracy: 0.28457, Validation accuracy: 0.28911
Epoch: 6, Train accuracy: 0.28013, Validation accuracy: 0.27797
Epoch: 7, Train accuracy: 0.28282, Validation accuracy: 0.26582
Epoch: 8, Train accuracy: 0.27537, Validation accuracy: 0.28759
Epoch: 9, Train accuracy: 0.27994, Validation accuracy: 0.26430
Epoch: 10, Train accuracy: 0.28407, Validation accuracy: 0.27696
Epoch: 11, Train accuracy: 0.29252, Validation accuracy: 0.27848
Epoch: 12, Train accuracy: 0.28670, Validation accuracy: 0.27595
Epoch: 13, Train accuracy: 0.28363, Validation accuracy: 0.28304
Epoch: 14, Train accuracy: 0.28050, Validation accuracy: 0.28051
Epoch: 15, Train accuracy: 0.27950,

In [198]:
autograd.set_detect_anomaly(True)
epoch = 20
batch_size = 8
train_loss = 0
val_loss = 0
test_loss = 0
for epoch in range(epoch):
    loss_sum = 0
    correct = 0
    for (i, data) in enumerate(train_loader):
        inputs, labels = data
        labels = labels.type(torch.LongTensor)
        inputs, labels = inputs.to(device), labels.to(device)
        y_pred = torch.Tensor(model(inputs.squeeze(dim = 0)))
        y_pred = torch.reshape(y_pred, (-1,7))
        y_pred_mode = torch.argmax(y_pred, 1)
        
        y_pred = y_pred * weight[:, None]
        label = labels[0][:-1]-1
        correct = correct + torch.sum(label == y_pred_mode).item()
        loss = criterion(y_pred, label)
        loss_sum = loss_sum + loss
        if (len(train_loader) - i) % batch_size == 1:
            loss_sum = loss_sum / len(train_loader)
            train_loss = loss_sum.item()
            train_acc = correct / (len(train_loader) * 25)
            optimizer.zero_grad()
            loss_sum.backward()
            optimizer.step()
            loss_sum = 0
    with torch.no_grad():
        correct = 0
        for (i, data) in enumerate(val_loader):
            count = count + 1
            inputs, labels = data
            labels = labels.type(torch.LongTensor)
            inputs, labels = inputs.to(device), labels.to(device)
            y_pred = torch.Tensor(model(inputs.squeeze(dim = 0)))
            y_pred = torch.reshape(y_pred, (-1,7))
            y_pred_mode = torch.argmax(y_pred, 1)
            label = labels[0][:-1]-1
            correct = correct + torch.sum(label == y_pred_mode).item()
        val_acc = correct / (len(val_loader) * 25)
    print("Epoch: {0}, Train accuracy: {1:0.5f}, Validation accuracy: {2:0.5f}".format(epoch, train_acc, val_acc))
with torch.no_grad():
        correct = 0
        for (i, data) in enumerate(test_loader):
            inputs, labels = data
            labels = labels.type(torch.LongTensor)
            inputs, labels = inputs.to(device), labels.to(device)
            y_pred = torch.Tensor(model(inputs.squeeze(dim = 0)))
            y_pred = torch.reshape(y_pred, (-1,7))
            y_pred_mode = torch.argmax(y_pred, 1)
            label = labels[0][:-1]-1
            correct = correct + torch.sum(label == y_pred_mode).item()
        test_acc = correct / (len(test_loader) * 25)
print("Test accuracy: {0:0.5f}".format(test_acc))

Epoch: 0, Train accuracy: 0.29077, Validation accuracy: 0.28759
Epoch: 1, Train accuracy: 0.28538, Validation accuracy: 0.29418
Epoch: 2, Train accuracy: 0.29302, Validation accuracy: 0.30278
Epoch: 3, Train accuracy: 0.28482, Validation accuracy: 0.28203
Epoch: 4, Train accuracy: 0.28620, Validation accuracy: 0.29063
Epoch: 5, Train accuracy: 0.28657, Validation accuracy: 0.28557
Epoch: 6, Train accuracy: 0.27775, Validation accuracy: 0.26987
Epoch: 7, Train accuracy: 0.28670, Validation accuracy: 0.29165
Epoch: 8, Train accuracy: 0.28720, Validation accuracy: 0.29063
Epoch: 9, Train accuracy: 0.29058, Validation accuracy: 0.29316
Epoch: 10, Train accuracy: 0.29014, Validation accuracy: 0.27089
Epoch: 11, Train accuracy: 0.29271, Validation accuracy: 0.28658
Epoch: 12, Train accuracy: 0.28833, Validation accuracy: 0.28861
Epoch: 13, Train accuracy: 0.28720, Validation accuracy: 0.29468
Epoch: 14, Train accuracy: 0.28482, Validation accuracy: 0.28759
Epoch: 15, Train accuracy: 0.28951,

## Train Regression Model

In [266]:
def reshape(func, x):
    y = x[0][0]
    y = torch.permute(y, (1,0,2))
    y = torch.reshape(y, (-1, 8*768))
    y = torch.stack([
        func(x) for x in torch.unbind(y, dim=0)
    ], dim=0)
    return y

In [275]:
class Classifier(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        pretrained_model,
        input_dim,
        inner_dim1,
        inner_dim2,
        num_classes,
        activation_fn,
        pooler_dropout
    ):
        super().__init__()
        self.pretrained_model = pretrained_model
        self.dropout1 = nn.Dropout(p=pooler_dropout)
        self.dense1 = nn.Linear(input_dim, inner_dim1)
        self.activation_fn1 = activation_fn
        self.dropout2 = nn.Dropout(p=pooler_dropout)
        self.dense2 = nn.Linear(inner_dim1, inner_dim2)
        self.activation_fn2 = activation_fn
        self.dropout3 = nn.Dropout(p=pooler_dropout)
        self.dense3 = nn.Linear(inner_dim2, num_classes)
        self.activation_fn3 = nn.Sigmoid()

    def forward(self, x):
        y = self.pretrained_model.forward(x, last_state_only=True)
        y = reshape(self.pretrained_model.downsampling, y)
        y = y[0]  # take <s> token (equiv. to [CLS])
        y = self.dropout1(y)
        y = self.dense1(y)
        y = self.activation_fn1(y)
        y = self.dropout2(y)
        y = self.dense2(y)
        y = self.activation_fn2(y)
        y = self.dropout3(y)
        y = self.dense3(y)
        y = self.activation_fn3(y)
        return y
    
    def freeze_pretrained_model(self):
        for i, (name, param) in enumerate(self.pretrained_model.named_parameters()):
            param.requires_grad = False
            
    def unfreeze_pretrained_model(self):
        for i, (name, param) in enumerate(self.pretrained_model.named_parameters()):
            param.requires_grad = True

In [276]:
generator = torch.Generator()
generator.manual_seed(0)

dataset = MyDataSet(x, y, length)
dataset_size = len(dataset)
train_size = int(dataset_size * 0.8)
val_size = int(dataset_size * 0.1)
test_size = dataset_size - train_size - val_size
train_data, val_data, test_data = random_split(dataset=dataset,
                                               lengths=[train_size, val_size, test_size],
                                               generator=generator)
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1, shuffle=True)

In [277]:
model = Classifier(pretrained_model = roberta_base1,
                  input_dim = 768,
                  inner_dim1 = 96,
                  inner_dim2 = 64,
                  num_classes = 25,
                  activation_fn = nn.Sigmoid(),
                  pooler_dropout = 0.2
                  )
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.5)
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer,
                                       step_size=5, gamma=0.1)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

Classifier(
  (pretrained_model): OctupleEncoder(
    (dropout_module): FairseqDropout()
    (embed_tokens): Embedding(1237, 768, padding_idx=1)
    (embed_positions): LearnedPositionalEmbedding(8194, 768, padding_idx=1)
    (emb_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0): TransformerSentenceEncoderLayer(
        (dropout_module): FairseqDropout()
        (activation_dropout_module): FairseqDropout()
        (self_attn): MultiheadAttention(
          (dropout_module): FairseqDropout()
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=768, out_features=3072, bi

In [278]:
model.freeze_pretrained_model()
autograd.set_detect_anomaly(True)
epoch = 20
batch_size = 4
train_loss = 0
val_loss = 0
test_loss = 0
count = 0
for epoch in range(epoch):
    loss_sum = 0
    train_loss = 0
    count = 0
    model.train()
    for (i, data) in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        label = labels[0][:-1]
        y_pred = torch.Tensor(model(inputs.squeeze(dim = 0)))
        loss_sum = loss_sum + criterion(y_pred, label)
        count = count + 1
        if (len(train_loader) - i) % batch_size == 1:
            train_loss = train_loss + loss_sum.item()
            loss = loss_sum / count
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_sum = 0
            count = 0
    train_loss = train_loss / len(train_loader)
    model.eval()
    with torch.no_grad():
        loss_sum = 0
        for (i, data) in enumerate(val_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            label = labels[0][:-1]
            y_pred = torch.Tensor(model(inputs.squeeze(dim = 0)))
            loss = criterion(y_pred, label)
            loss_sum = loss_sum + loss.item()
        val_loss = loss_sum / len(val_loader)
    print("Epoch: {0}, Train loss: {1:0.5f}, Validation loss: {2:0.5f}".format(epoch, train_loss, val_loss))
    scheduler.step()
with torch.no_grad():
        loss_sum = 0
        for (i, data) in enumerate(test_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            label = labels[0][:-1]
            y_pred = torch.Tensor(model(inputs.squeeze(dim = 0)))
            loss = criterion(y_pred, label)
            loss_sum = loss_sum + loss.item()
        test_loss = loss_sum / len(test_loader)
print("Test loss: {0:0.5f}".format(test_loss))

Epoch: 0, Train loss: 0.03471, Validation loss: 0.03000
Epoch: 1, Train loss: 0.03036, Validation loss: 0.02775
Epoch: 2, Train loss: 0.02901, Validation loss: 0.02708
Epoch: 3, Train loss: 0.02858, Validation loss: 0.02687
Epoch: 4, Train loss: 0.02833, Validation loss: 0.02681
Epoch: 5, Train loss: 0.02844, Validation loss: 0.02680
Epoch: 6, Train loss: 0.02837, Validation loss: 0.02680
Epoch: 7, Train loss: 0.02837, Validation loss: 0.02680
Epoch: 8, Train loss: 0.02852, Validation loss: 0.02680
Epoch: 9, Train loss: 0.02831, Validation loss: 0.02680
Epoch: 10, Train loss: 0.02819, Validation loss: 0.02680
Epoch: 11, Train loss: 0.02825, Validation loss: 0.02680
Epoch: 12, Train loss: 0.02838, Validation loss: 0.02680
Epoch: 13, Train loss: 0.02830, Validation loss: 0.02680
Epoch: 14, Train loss: 0.02851, Validation loss: 0.02680
Epoch: 15, Train loss: 0.02841, Validation loss: 0.02680
Epoch: 16, Train loss: 0.02829, Validation loss: 0.02680
Epoch: 17, Train loss: 0.02835, Validatio