# Define Model

In [69]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from transformers import T5Config, T5ForConditionalGeneration, AdamW
from transformers import BertConfig, BertModel, BertLMHeadModel, BartConfig, BartModel

# Root directory holding the trained chord/note checkpoints.
path_prefix = "/workspace/out/chord_note"

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

class Chord_Note_LSTM(nn.Module):
    """Stacked LSTM that predicts the next chord id from chord and note ids.

    At every time step the chord embedding is concatenated with the mean of
    that step's note embeddings, run through the LSTM, and projected to
    chord-vocabulary logits.

    Args:
        chord_size: Chord vocabulary size (id 0 is the padding index).
        note_size: Note vocabulary size (id 0 is the padding index).
        chord_dim: Chord embedding size.
        note_dim: Note embedding size.
        num_note: Number of note slots per step. Kept for backward
            compatibility; because the note embeddings are mean-pooled over
            the slot axis it does not change the LSTM input size.
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, chord_size=150, note_size=832, chord_dim=512, note_dim=512, num_note=2, hidden_dim=512, num_layers=5):
        super().__init__()
        self.chord_embedding = nn.Embedding(chord_size, chord_dim, padding_idx=0)
        self.note_embedding = nn.Embedding(note_size, note_dim, padding_idx=0)
        self.num_note = num_note
        # forward() averages the note embeddings over the slot axis, so each
        # step feeds one note_dim vector to the LSTM whatever num_note is.
        # (Sizing this by note_dim * num_note broke every num_note != 1.)
        self.lstm = nn.LSTM(chord_dim + note_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, chord_size)

        self.epsilon = 1e-9
        self.optimizer = optim.AdamW(self.parameters(), lr=0.01, eps=1e-9)
        # Exponential decay of the base learning rate: lr = 0.01 * 0.95 ** epoch.
        self.scheduler = optim.lr_scheduler.LambdaLR(optimizer=self.optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)

    def forward(self, note, chord):
        """Return next-chord logits for every position.

        Args:
            note: Integer note ids of shape (batch, seq, slots).
            chord: Integer chord ids of shape (batch, seq).

        Returns:
            Logits of shape (batch, seq, chord_size).

        Raises:
            ValueError: if ``note`` is not 3-D.
        """
        if note.dim() != 3:
            raise ValueError(f"note must have shape (batch, seq, slots), got {tuple(note.shape)}")

        chord_embed = self.chord_embedding(chord)
        # (batch, seq, slots, note_dim) -> (batch, seq, note_dim). Padding
        # slots (id 0) use the padding embedding and still count in the mean.
        note_embed = self.note_embedding(note).mean(dim=2)

        input_embed = torch.cat([chord_embed, note_embed], dim=2)
        output, _ = self.lstm(input_embed)
        return self.fc(output)

    def infer(self, note, chord, length=2048):
        """Greedily extend a chord prompt one token at a time.

        The full prefix is re-run through the model at every step. Generation
        stops early when the prefix covers all note steps, when batch element
        0 predicts id 1 (treated as EOS), or when the sequence exceeds 2048.

        Args:
            note: Note ids, shape (batch, seq, slots) or unbatched (seq, slots);
                position t of the generated sequence is conditioned on note[:, :t+1].
            chord: Chord-id prompt, shape (batch, prompt_len) or (prompt_len,).
            length: Maximum number of generation steps (capped at 2048).

        Returns:
            The prompt followed by the generated ids, shape (batch, prompt_len + steps).
            When ``length <= 0`` the (batched) prompt is returned unchanged.

        Raises:
            ValueError: if ``chord`` has more than two dimensions.
        """
        if chord.dim() == 1:
            chord = chord.unsqueeze(0)
        if chord.dim() > 2:
            raise ValueError(f"chord must be 1-D or 2-D, got shape {tuple(chord.shape)}")
        if note.dim() == 2:
            note = note.unsqueeze(0)

        if length > 2048:
            print("Max Length is 2048. Change Length Auto to 2048")
            length = 2048

        output_ids = chord  # also the result when no step runs
        with torch.no_grad():
            for step in range(length):
                chord_length = output_ids.shape[1]

                logits = self(note[:, :chord_length], output_ids)
                # Greedy decoding: most likely chord at the last position only.
                predict = logits[:, -1].argmax(dim=-1, keepdim=True)
                output_ids = torch.cat((output_ids, predict), dim=-1)

                if chord_length == note.shape[1]:
                    print("MAX NOTE LENGTH")
                    break

                # NOTE: only batch element 0 is checked for EOS.
                if predict[0, 0] == 1:
                    print("EOS DETECT!!")
                    break

                if output_ids.shape[1] > 2048:
                    break

                if step % 50 == 0:
                    print(f'{step} Generate...')

        return output_ids
    
    
# Build the model with the hyper-parameters matching the checkpoint loaded below.
model = Chord_Note_LSTM(
    chord_dim=256,
    note_dim=256,
    num_note=1,
    hidden_dim=256,
    num_layers=3,
)
model.load_state_dict(
    torch.load(f"{path_prefix}/LSTM_4Note_Avg/model_72_0.3374_0.2445.pt", map_location=device)
)
# Move to the target device and switch to evaluation mode; being the cell's
# last expression, the returned module is what the notebook displays.
model.to(device).eval()



cuda:0


Chord_Note_LSTM(
  (chord_embedding): Embedding(150, 256, padding_idx=0)
  (note_embedding): Embedding(832, 256, padding_idx=0)
  (lstm): LSTM(512, 256, num_layers=3, batch_first=True)
  (fc): Linear(in_features=256, out_features=150, bias=True)
)

# Load Sample Dataset

In [9]:
# Sampling configuration consumed by create_dataloaders further down.
batch_size, base_model, n_notes = 1, 'LSTM', 4

In [10]:
import torch
import json
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split

class SymDataset(Dataset):
    """Convert token sequences into aligned (chord ids, note ids) tensors.

    Each sample is a whitespace-separated token string (or a token list).
    Tokens are read in groups of ``group_size`` and only the first token of
    each group is inspected:

    * a token starting with 'm'/'M' closes the current measure,
    * a token starting with 'h'/'H' is a chord token,
    * a token found in the note vocabulary is recorded (once) for the
      current measure.

    For every measure the last ``n_notes`` distinct notes are kept, so each
    chord position is paired with one measure's notes. Both outputs are
    framed with id 2 at the start and id 1 at the end.
    """

    # Tokens are consumed in groups of this size; only the first is used.
    group_size = 4
    # Maximum number of chord / measure steps kept, excluding the 2 and 1 markers.
    max_body_len = 766

    def __init__(self, data, base_model, n_notes,
                 chord_vocab_path='/workspace/data/vocabs/chord.json',
                 note_vocab_path='/workspace/data/vocabs/note.json'):
        """
        Args:
            data: Sequence of samples (token strings or token lists).
            base_model: Stored as an attribute; not used by this class.
            n_notes: Number of note slots kept per measure.
            chord_vocab_path: JSON file mapping chord token -> id.
            note_vocab_path: JSON file mapping note token -> id.
        """
        super().__init__()
        self.data = data
        self.base_model = base_model
        self.n_notes = n_notes

        with open(chord_vocab_path, 'r', encoding='utf-8') as file:
            self.chord_vocab = json.load(file)
        with open(note_vocab_path, 'r', encoding='utf-8') as file:
            self.note_vocab = json.load(file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(chord_ids, note_ids)`` for sample ``idx``.

        ``chord_ids`` has shape (L,) and ``note_ids`` has shape (L, n_notes).

        Raises:
            ValueError: if the number of chord tokens and measures disagree.
        """
        text_seq = self.data[idx]
        toks = text_seq.split() if isinstance(text_seq, str) else list(text_seq)

        chord_list = []
        note_list = []
        note_in_measure = []
        # Walk the tokens in groups of group_size, looking only at the first
        # token of each group (a trailing partial group is tolerated).
        for tok in toks[::self.group_size]:
            head = tok[:1]
            if head in ('m', 'M'):
                # A measure marker: store the notes gathered since the
                # previous marker, then start collecting afresh.
                note_list.append(note_in_measure[-self.n_notes:])
                note_in_measure = []
            if head in ('h', 'H'):
                chord_list.append(tok)
            if tok in self.note_vocab and tok not in note_in_measure:
                note_in_measure.append(tok)

        # Flush the notes of the final measure (keeping the last n_notes).
        note_list.append(note_in_measure[-self.n_notes:])
        # Drop the entry pushed by the first measure marker (notes seen before
        # any measure) and apply the length cap.
        note_list = note_list[1:self.max_body_len + 1]

        # Chord tokens -> ids, capped, framed with 2 ... 1.
        chord_ids = [self.chord_vocab[chd] for chd in self.get_chord_seq(chord_list)]
        target_chord_tensor = torch.tensor([2] + chord_ids[:self.max_body_len] + [1])

        # Note tokens -> ids, framed with rows of 2 ... rows of 1.
        note_ids = self.get_note_seq(note_list)
        target_note_tensor = torch.tensor([[2] * self.n_notes] + note_ids + [[1] * self.n_notes])

        if target_note_tensor.shape[0] != target_chord_tensor.shape[0]:
            raise ValueError(
                f"sample {idx}: {target_chord_tensor.shape[0]} chord steps but "
                f"{target_note_tensor.shape[0]} note steps"
            )

        return target_chord_tensor, target_note_tensor

    def get_chord_seq(self, chord_list):
        """Return the chord tokens as a new list (one entry per chord token)."""
        return list(chord_list)

    def get_note_seq(self, note_list):
        """Map each measure's note tokens to ids, right-padded with 0 to n_notes.

        Raises:
            ValueError: if a measure holds more than ``n_notes`` notes.
        """
        target_note_tensor = []
        for n_list in note_list:
            if len(n_list) > self.n_notes:
                raise ValueError(f"measure has {len(n_list)} notes, more than n_notes={self.n_notes}")
            ids = [self.note_vocab[n] for n in n_list]
            target_note_tensor.append(ids + [0] * (self.n_notes - len(ids)))
        return target_note_tensor
    
def create_dataloaders(batch_size, base_model, n_notes,
                       raw_data_path='../../../workspace/data/corpus/raw_corpus_bpe.txt',
                       random_state=5):
    """Read the raw corpus and build train / val / test DataLoaders.

    One corpus line is one sample. 10% of the lines are held out, and that
    pool is split 80/20 into validation and test, i.e. about 90% / 8% / 2%
    of the corpus overall. The loaders are not shuffled and batch with
    ``collate_batch``.

    Args:
        batch_size: Batch size for all three loaders.
        base_model: Forwarded to SymDataset.
        n_notes: Note slots per step, forwarded to SymDataset.
        raw_data_path: Corpus file with one sample per line.
            NOTE(review): the default is relative to the working directory,
            unlike SymDataset's absolute vocab paths; confirm it resolves to
            /workspace/data/corpus from where the notebook runs.
        random_state: Seed used for both train_test_split calls.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    with open(raw_data_path, 'r', encoding='utf-8') as f:
        raw_data = [line.strip() for line in tqdm(f, desc="reading original txt file...")]

    train, val_test = train_test_split(raw_data, test_size=0.1, random_state=random_state)
    val, test = train_test_split(val_test, test_size=0.2, random_state=random_state)

    train_loader, val_loader, test_loader = (
        DataLoader(SymDataset(split, base_model, n_notes), batch_size=batch_size, collate_fn=collate_batch)
        for split in (train, val, test)
    )
    return train_loader, val_loader, test_loader

def collate_batch(batch):
    """Pad a batch of (chord_ids, note_ids) pairs for teacher forcing.

    Sequences are right-padded with id 0 (the embeddings' padding index,
    not the end marker 1).

    Returns:
        chord_in: padded chords without the last position (model input),
        chord_target: padded chords without the first position (labels),
        note_in: padded notes without the last step, aligned with chord_in.
    """
    chord_seqs, note_seqs = zip(*batch)
    padded_chords = pad_sequence(chord_seqs, batch_first=True, padding_value=0)
    padded_notes = pad_sequence(note_seqs, batch_first=True, padding_value=0)

    chord_in = padded_chords[:, :-1]
    chord_target = padded_chords[:, 1:]
    note_in = padded_notes[:, :-1, :]
    return chord_in, chord_target, note_in
    
# Build the train / val / test loaders from the sampling configuration.
train_loader, val_loader, test_loader = create_dataloaders(
    batch_size=batch_size,
    base_model=base_model,
    n_notes=n_notes,
)

reading original txt file...: 1198it [00:00, 4090.55it/s]

reading original txt file...: 46188it [00:10, 4244.87it/s]


# Get Sample Data

In [74]:
# Take the first batch of every split and move it to the model's device.
# The names i / o / note are reassigned each time so that, as before, they
# refer to the test split's first batch afterwards.
i, o, note = next(iter(train_loader))
train_in_sample, train_out_sample, train_note = i.to(device), o.to(device), note.to(device)

i, o, note = next(iter(val_loader))
val_in_sample, val_out_sample, val_note = i.to(device), o.to(device), note.to(device)

i, o, note = next(iter(test_loader))
test_in_sample, test_out_sample, test_note = i.to(device), o.to(device), note.to(device)

# One print per item, each on its own line.
print("Train IN sample", train_in_sample, train_in_sample.shape,
      "Train OUT sample", train_out_sample, train_out_sample.shape,
      "Train Note sample", train_note, train_note.shape,
      sep="\n")

Train IN sample
tensor([[  2,  15,  44,  92,  61,  15,  44,  92,  61,  16,  92,  93,  61,  16,
          92,  93,  62,  15,  21,  92,  61,  15,  21,  92,  61,  16,  92,  92,
          65,  16,  92,  92,  61, 117,  92,  92,  61,  61,  21,  92,  16,  15,
          92,  92,  61,  15,  92,  92,  61,  15,  34,  92,  16,  15,  92,  93,
          61,  16,  37,  92,  84, 117,  37,  92, 116, 117,   7,  92,  61,  16,
          37,   6,  59,  15,  93,  93,  61,  15,  93,  93,  61,  16,  92,  92,
          65,  16,  92,  92,  61,  16,  92,  92,  61,  61, 118,  92,  61,  16,
          36,  92,  61,  16, 122,  92,  61, 100,  34,  92,  61,  16,  36,   7,
          61,  15,  92,  92,  61,  15,  92,  92,  61,  15,  92,  92,  61,  15,
          92,  92,  61,  16,  92,  92,  61,  15,  40,  92,  61,  16,  92,  92,
          61,  15,  93,  92,  61,  16,  36,   6,  16,  16,  36,  93,  62,  16,
          36,  37,  62]], device='cuda:0')
torch.Size([1, 157])
Train OUT sample
tensor([[ 15,  44,  92,  61,  15, 

# Test Inference

In [47]:
# Show the 30-token chord prompt, the full target chord sequence, and the notes.
print("Prompt Chord", val_in_sample[:, :30],
      "Answer Chord", val_out_sample, val_out_sample.shape,
      "Prompt Note", val_note, val_note.shape,
      sep="\n")

Prompt Chord
tensor([[  2,   5,  52,  82,  29,   5,  52, 133,  29,  48,  52,  82,  29,   5,
          52,  74,  29,  50,  71, 128, 128,  48,  50,  81,  48,   5,   4,  25,
          80,  49]], device='cuda:0')
Answer Chord
tensor([[  5,  52,  82,  29,   5,  52, 133,  29,  48,  52,  82,  29,   5,  52,
          74,  29,  50,  71, 128, 128,  48,  50,  81,  48,   5,   4,  25,  80,
          49,  71,  70,  92,  70, 127,  51,  51,  70, 128,  51,  51,  33,  32,
           5,  52, 129,  29,  51,  52, 129,  82,  48,  69,  25,  82,  51,  69,
          25, 129,   5,  52, 129,  29,  51,  52, 129,  80,  26, 128,  26, 128,
          26, 128, 103, 128,  47,  47, 126, 126,  16, 117,  62,  62,  82, 102,
         132, 126,  82, 102, 132, 125, 127,   7,   7,  62,  80,   7,  62,  62,
          80,   7,   7,  62,  80,   7,  93,  62,   7,  37, 133,  48,  52,  82,
         115,   5,  52, 107,  29,  51,  52,   5,  29,  51,  52, 133, 115,   5,
          50,   1]], device='cuda:0')
torch.Size([1, 128])
Prompt N

In [81]:
# Greedily continue the first 30 chord tokens of the training sample.
sample_output = model.infer(train_note, train_in_sample[:, :30])
print(sample_output, sample_output.shape, sep="\n")

0 Generate...
50 Generate...
EOS DETECT!!
tensor([[ 2, 15, 44, 92, 61, 15, 44, 92, 61, 16, 92, 93, 61, 16, 92, 93, 62, 15,
         21, 92, 61, 15, 21, 92, 61, 16, 92, 92, 65, 16, 92, 92, 61, 15, 92, 92,
         61, 15, 92, 92, 61, 15, 92, 92, 61, 15, 92, 92, 61, 15, 92, 92, 61, 15,
         92, 92, 61, 15, 92, 92, 59, 15, 92, 92, 43, 21, 92, 92, 61, 15, 92, 92,
         61, 15, 92, 92, 61, 15, 92, 92, 61, 15, 92, 92, 61, 15, 92, 92, 61, 15,
         92, 92, 61, 15, 92, 92, 61, 15, 92, 92, 61, 15, 92, 92, 61, 15, 92, 92,
         61, 15, 92, 92, 61, 15, 92, 92, 61, 15, 92, 92,  1]], device='cuda:0')
torch.Size([1, 121])
