In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import ASTConfig, ASTModel, GPT2Config, GPT2Model, AutoModelForCausalLM, GPT2LMHeadModel

class GPT2Model(nn.Module):
    """Wrapper around a freshly configured Hugging Face ``GPT2LMHeadModel``.

    NOTE(review): this class reuses the name ``GPT2Model`` that is also imported
    from ``transformers`` at the top of the cell; after this definition the name
    refers to this wrapper, not to the library class.

    Token id conventions visible in this class: ``0`` is treated as padding
    (masked out of attention), ``2`` is BOS and ``1`` is EOS.
    """

    # Hard upper bound on the total sequence length produced by ``infer``.
    _MAX_LENGTH = 2048

    def __init__(self, vocab_size=140):
        """Build the GPT-2 configuration and the language-model head.

        Args:
            vocab_size: number of token ids the embedding/output layers cover.
        """
        super().__init__()
        self.configuration = GPT2Config(vocab_size=vocab_size, bos_token_id=2, eos_token_id=1)
        self.model = GPT2LMHeadModel(self.configuration)

    def forward(self, input_ids, labels=None):
        """Run the wrapped model with an attention mask derived from ``input_ids``.

        Args:
            input_ids: integer tensor of token ids; id 0 is treated as padding.
            labels: optional tensor forwarded unchanged to the wrapped model.

        Returns:
            Whatever the wrapped model returns (``infer`` reads its ``.logits``).
        """
        attention_mask = self.make_mask(input_ids)
        output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return output

    def make_mask(self, input_ids):
        """Return a long tensor that is 1 where ``input_ids != 0`` and 0 elsewhere."""
        attention_mask = (input_ids != 0).long()
        return attention_mask

    def infer(self, input_ids, length=2048):
        """Greedily extend ``input_ids`` by up to ``length`` tokens.

        Args:
            input_ids: 1-D (single prompt) or 2-D (batch of prompts) tensor of
                token ids. A 1-D tensor is promoted to a batch of one.
            length: maximum number of tokens to append; values above 2048 are
                clamped to 2048.

        Returns:
            2-D tensor holding the prompt followed by the generated tokens. When
            nothing can be generated (``length <= 0`` or the prompt already
            fills the context window) the (batched) prompt is returned as is.

        Raises:
            ValueError: if ``input_ids`` is not 1-D or 2-D.
        """
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if input_ids.dim() != 2:
            raise ValueError(
                "input_ids must be a 1-D or 2-D tensor, got %d dimensions" % input_ids.dim()
            )

        if length > self._MAX_LENGTH:
            print("Max Length is 2048. Change Length Auto to 2048")
            length = self._MAX_LENGTH

        # The position-embedding table has `n_positions` rows; feeding a longer
        # sequence would index past it, so never grow beyond that either.
        max_positions = getattr(self.configuration, "n_positions", self._MAX_LENGTH)
        max_total = min(self._MAX_LENGTH, max_positions)

        # Start from the prompt so a zero-step generation still returns a tensor.
        output_ids = input_ids

        # Dropout must be disabled for deterministic greedy decoding; restore the
        # caller's training/eval mode afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(length):
                    if output_ids.shape[1] >= max_total:
                        break
                    logits = self.forward(output_ids).logits
                    # Only the prediction at the last position is appended.
                    predict = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                    output_ids = torch.cat((output_ids, predict), dim=-1)
        finally:
            self.train(was_training)

        return output_ids


In [None]:
import torch
import json
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split

class BPE_Chord_Dataset(Dataset):
    """Dataset that turns whitespace-separated token strings into chord-BPE id tensors.

    Each sample string is read in groups of four tokens; the first token of a
    group is kept when it starts with ``h`` or ``H``. The kept tokens are then
    greedily merged into vocabulary entries and wrapped as ``[2] + ids + [1]``.
    """

    # Default vocabulary file (JSON mapping of merged chord string -> id).
    DEFAULT_VOCAB_PATH = '/workspace/pj/data/vocabs/chord_bpe_20000.json'
    # Number of tokens per group in the raw sequence; only the first is inspected.
    GROUP_SIZE = 4
    # Ids wrapped around every sequence, and the cap on ids kept in between.
    BOS_ID = 2
    EOS_ID = 1
    MAX_BODY_TOKENS = 510

    def __init__(self, data, base_model, vocab_path=None):
        """Store the samples and load the vocabulary.

        Args:
            data: indexable collection of sample strings.
            base_model: stored on the instance; not used by this class itself.
            vocab_path: optional path to the vocabulary JSON; defaults to
                ``DEFAULT_VOCAB_PATH``.
        """
        super().__init__()
        self.data = data
        self.base_model = base_model
        if vocab_path is None:
            vocab_path = self.DEFAULT_VOCAB_PATH
        with open(vocab_path, 'r', encoding='utf-8') as file:
            self.vocab = json.load(file)

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the id tensor ``[2, <up to 510 chord ids>, 1]`` for sample ``idx``.

        Raises:
            TypeError: if the stored sample is not a string.
            KeyError: if a (merged) chord is missing from the vocabulary.
        """
        text_seq = self.data[idx]
        if not isinstance(text_seq, str):
            raise TypeError(
                "expected sample %r to be a str, got %s" % (idx, type(text_seq).__name__)
            )

        toks = text_seq.split()
        # First token of every group of GROUP_SIZE; a trailing partial group is
        # tolerated because only its first token is ever looked at.
        chord_list = [
            tok for tok in toks[::self.GROUP_SIZE] if tok.startswith(('h', 'H'))
        ]

        target_chord_seq = self.tokenizing_chord_seq(chord_list)
        target_chord_ids = (
            [self.BOS_ID] + target_chord_seq[:self.MAX_BODY_TOKENS] + [self.EOS_ID]
        )
        return torch.tensor(target_chord_ids)

    def tokenizing_chord_seq(self, chord_list):
        """Greedily merge consecutive chords into vocabulary entries.

        Starting from the first chord, the next chord is concatenated onto the
        current string for as long as the concatenation is a vocabulary key;
        otherwise the current string's id is emitted and a new string starts.

        Args:
            chord_list: list of chord strings.

        Returns:
            List of vocabulary ids; empty when ``chord_list`` is empty.
        """
        if not chord_list:
            return []

        group_list = []
        cur_chord = chord_list[0]
        for chord in chord_list[1:]:
            merged = cur_chord + chord
            if merged in self.vocab:
                cur_chord = merged
            else:
                group_list.append(self.vocab[cur_chord])
                cur_chord = chord
        group_list.append(self.vocab[cur_chord])

        return group_list
    
def create_BPE_dataloaders(batch_size, base_model,
                           raw_data_path='../../../workspace/pj/data/corpus/raw_corpus_bpe.txt'):
    """Read the raw corpus and build train/validation/test ``DataLoader`` objects.

    The corpus is read line by line (one sample per line). 10% of the samples
    are held out, and that held-out part is split again 80/20 into validation
    and test. Both splits use ``random_state=5``.

    Args:
        batch_size: batch size used for all three loaders.
        base_model: forwarded to ``BPE_Chord_Dataset``.
        raw_data_path: path of the corpus text file; defaults to the location
            that was previously hard-coded.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. None of the loaders
        shuffles.
    """
    raw_data = []
    with open(raw_data_path, 'r') as f:
        for line in tqdm(f, desc="reading original txt file..."):
            line = line.strip()
            # Blank lines would become empty samples with no chords to tokenize.
            if line:
                raw_data.append(line)

    train, val_test = train_test_split(raw_data, test_size=0.1, random_state=5)
    val, test = train_test_split(val_test, test_size=0.2, random_state=5)

    train_dataset = BPE_Chord_Dataset(train, base_model)
    val_dataset = BPE_Chord_Dataset(val, base_model)
    test_dataset = BPE_Chord_Dataset(test, base_model)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=BPE_collate_batch)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=BPE_collate_batch)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=BPE_collate_batch)

    return train_loader, val_loader, test_loader

def BPE_collate_batch(batch):
    """Collate variable-length id tensors into one right-padded batch.

    Sequences are padded with 0 up to the longest one in the batch, batch
    dimension first. The very same padded tensor is returned as both the
    inputs and the targets (no one-token shift is applied here).
    """
    inputs = targets = pad_sequence(batch, batch_first=True, padding_value=0)
    return inputs, targets




In [None]:
# Configuration for the chord-BPE GPT-2 checkpoint evaluated in this notebook.
vocab_size = 20000
batch_size=2
device = torch.device("cuda:9" if torch.cuda.is_available() else "cpu")
print(device)

base_model = 'GPT2'
model = GPT2Model(vocab_size=vocab_size)
# NOTE(review): `map_location` only controls where the checkpoint tensors are
# loaded; the model itself is not moved to `device` in the cells shown here.
model.load_state_dict(torch.load('/workspace/pj/out/chord_bpe/GPT2_BPE_V20000/model_189_0.2063_0.3527.pt', map_location=device))
# Switch off dropout: this notebook only runs inference with the loaded weights.
model.eval()
train_loader, val_loader, test_loader = create_BPE_dataloaders(batch_size, base_model)

In [None]:
# Walk the test loader, remembering the most recent batch, and stop at the
# batch with index 4 (the fifth one), printing it.
cnt = 0
for inputs, targets in test_loader:
    sample_input, sample_target = inputs, targets
    if cnt != 4:
        cnt += 1
        continue
    print(inputs)
    print(targets)
    break

In [None]:
# Use the first 6 token ids of every row as the prompt.
sample_input = sample_input[:, 0:6]
print(sample_input)

# Greedily append up to 40 tokens to the prompt.
sample_out = model.infer(input_ids=sample_input, length=40)
print(sample_out)