In [13]:
import torch
import json
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split

class C2IDataset(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
        inst_vocab_path = '/workspace/pj/data/vocabs/inst.json'
        chord_vocab_path = '/workspace/pj/data/vocabs/chord.json'
        with open(inst_vocab_path, 'r') as file:
            self.inst_vocab = json.load(file)
        with open(chord_vocab_path, 'r') as file:
            self.chord_vocab = json.load(file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text_seq = self.data[idx]
        
        if isinstance(text_seq, str):
            toks = text_seq.split()
            
        l_toks = len(toks)
        ratio = 4
        chord_list = []
        inst_in_measure = []
        inst_list = []
        
        inst_tensor = torch.zeros(133)
        
        for idx in range(0, l_toks, ratio):
            t1, t2, t3, t4 = toks[idx : idx + 4]
            if t1[0] == 'H':
                chord_list.append(t1)

            if t4[0] == 'x' or t4[0] == 'X' or t4[0] == 'y' or t4 == '<unk>':
                inst_tensor[self.inst_vocab[t4]] = 1
        
        chord_tensor = [self.chord_vocab[chd] for chd in chord_list]
        
        target_chord_tensor = [2] + chord_tensor[:766] + [1]
        target_chord_tensor = torch.tensor(target_chord_tensor)
        
        target_inst_tensor = inst_tensor

        return target_chord_tensor, target_inst_tensor, target_chord_tensor.shape[0]
    
def create_dataloaders(batch_size):
    raw_data_path = '../../../workspace/pj/data/corpus/raw_corpus_bpe.txt'
    # raw_data_path = '../../../workspace/data/corpus/first_5_lines_bpe.txt'
    raw_data = []
    with open(raw_data_path, 'r') as f:
        for line in tqdm(f, desc="reading original txt file..."):
            raw_data.append(line.strip())
            
    train, val_test = train_test_split(raw_data, test_size=0.1, random_state=5)
    val, test = train_test_split(val_test, test_size=0.2, random_state=5)
    # train, val_test = train_test_split(raw_data, test_size=0.5, random_state=5)
    # val, test = train_test_split(val_test, test_size=0.2)
    
    train_dataset = InstDataset(train)
    val_dataset = InstDataset(val)
    test_dataset = InstDataset(test)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_batch)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_batch)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_batch)

    # return train_loader, True, True
    return train_loader, val_loader, test_loader

def collate_batch(batch):
    chords, insts, length = zip(*batch)
    # padding_value = <eos>
    chord_padded = pad_sequence(chords, padding_value=0, batch_first=True)
    inst_padded = pad_sequence(insts, padding_value=0, batch_first=True)
    length_padded = pad_sequence(length, padding_value=0, batch_first=True)
    return chord_padded, inst_padded, length_padded

def create_C2I(batch_size):
    raw_data_path = '../../../workspace/pj/data/corpus/raw_corpus_bpe.txt'
    # raw_data_path = '../../../workspace/data/corpus/first_5_lines_bpe.txt'
    raw_data = []
    with open(raw_data_path, 'r') as f:
        for line in tqdm(f, desc="reading original txt file..."):
            raw_data.append(line.strip())
            
    train, val_test = train_test_split(raw_data, test_size=0.1, random_state=5)
    val, test = train_test_split(val_test, test_size=0.2, random_state=5)
    
    train_dataset = C2IDataset(train)
    val_dataset = C2IDataset(val)
    test_dataset = C2IDataset(test)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_batch_C2I)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_batch_C2I)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_batch_C2I)

    # return train_loader, True, True
    return train_loader, val_loader, test_loader

def collate_batch_C2I(batch):
    chords, insts, length = zip(*batch)
    # padding_value = <eos>
    chord_padded = pad_sequence(chords, padding_value=0, batch_first=True)
    inst_padded = pad_sequence(insts, padding_value=0, batch_first=True)
    return chord_padded, inst_padded, length

torch.set_printoptions(profile="full")

In [14]:
train_loader, val_loader, test_loader = create_C2I(1)

reading original txt file...: 46188it [00:11, 3875.35it/s]


In [24]:
inst_sum = torch.zeros(133).long()
print(inst_sum)
cnt = 0

for (chords, targets, lengths) in tqdm(train_loader, ncols=60):
    inst_sum += targets.squeeze(0).long()

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])


100%|█████████████████| 41569/41569 [27:41<00:00, 25.03it/s]


In [25]:
for (chords, targets, lengths) in tqdm(val_loader, ncols=60):
    inst_sum += targets.squeeze(0).long()

  0%|                              | 0/3695 [00:00<?, ?it/s]

100%|███████████████████| 3695/3695 [02:23<00:00, 25.67it/s]


In [26]:
for (chords, targets, lengths) in tqdm(test_loader, ncols=60):
    inst_sum += targets.squeeze(0).long()

  0%|                               | 0/924 [00:00<?, ?it/s]

100%|█████████████████████| 924/924 [00:37<00:00, 24.73it/s]


In [23]:
print(inst_sum.long())

tensor([    0,     0,     0,     0, 25758,   723,   300,   342,   709,   195,
          862,   157,  1264, 11379,   730,  6129,  5857,  4921,  6224,   236,
          285,   229,   289,  1565,   277,   728,   577,   122,  1843,  1734,
          562,  2065,   415,   713,   652,   115,  4529,  3615,  2193,   355,
          146,   145,   374,   227, 16963, 11370, 14240, 14010,  1379,  8758,
         9241, 26186, 19252,  1011,  1029,   271,  6613,   450,   190,   208,
        36920, 30527, 25915,  1596, 34516,  1553,   563,   225,  2156, 16172,
        13834, 10948, 28672,  3767, 25725, 34334, 13628, 38641,   908,   579,
          112,    94,   417,   567,   669,   675,   191,    84,    37,   180,
           23,   358,   281,   268,   124,   251,   125,   120,   130,    92,
          258,    66,    96,   154,   151,    55,    78,    85,   144,   427,
           64,   192,   120,   206,   146,    39,   343,    15,   196,    58,
          254,   244,    29,   232,    11,    29,    65,    20, 

In [29]:
neg = 46188 - inst_sum
print(neg)

tensor([46188, 46188, 46188, 46188, 20430, 45465, 45888, 45846, 45479, 45993,
        45326, 46031, 44924, 34809, 45458, 40059, 40331, 41267, 39964, 45952,
        45903, 45959, 45899, 44623, 45911, 45460, 45611, 46066, 44345, 44454,
        45626, 44123, 45773, 45475, 45536, 46073, 41659, 42573, 43995, 45833,
        46042, 46043, 45814, 45961, 29225, 34818, 31948, 32178, 44809, 37430,
        36947, 20002, 26936, 45177, 45159, 45917, 39575, 45738, 45998, 45980,
         9268, 15661, 20273, 44592, 11672, 44635, 45625, 45963, 44032, 30016,
        32354, 35240, 17516, 42421, 20463, 11854, 32560,  7547, 45280, 45609,
        46076, 46094, 45771, 45621, 45519, 45513, 45997, 46104, 46151, 46008,
        46165, 45830, 45907, 45920, 46064, 45937, 46063, 46068, 46058, 46096,
        45930, 46122, 46092, 46034, 46037, 46133, 46110, 46103, 46044, 45761,
        46124, 45996, 46068, 45982, 46042, 46149, 45845, 46173, 45992, 46130,
        45934, 45944, 46159, 45956, 46177, 46159, 46123, 46168, 

In [31]:
pos_weight = neg / inst_sum
print(pos_weight.long())

tensor([-9223372036854775808, -9223372036854775808, -9223372036854775808,
        -9223372036854775808,                    0,                   62,
                         152,                  134,                   64,
                         235,                   52,                  293,
                          35,                    3,                   62,
                           6,                    6,                    8,
                           6,                  194,                  161,
                         200,                  158,                   28,
                         165,                   62,                   79,
                         377,                   24,                   25,
                          81,                   21,                  110,
                          63,                   69,                  400,
                           9,                   11,                   20,
                         129,         

In [27]:
print(inst_sum)

tensor([    0,     0,     0,     0, 25758,   723,   300,   342,   709,   195,
          862,   157,  1264, 11379,   730,  6129,  5857,  4921,  6224,   236,
          285,   229,   289,  1565,   277,   728,   577,   122,  1843,  1734,
          562,  2065,   415,   713,   652,   115,  4529,  3615,  2193,   355,
          146,   145,   374,   227, 16963, 11370, 14240, 14010,  1379,  8758,
         9241, 26186, 19252,  1011,  1029,   271,  6613,   450,   190,   208,
        36920, 30527, 25915,  1596, 34516,  1553,   563,   225,  2156, 16172,
        13834, 10948, 28672,  3767, 25725, 34334, 13628, 38641,   908,   579,
          112,    94,   417,   567,   669,   675,   191,    84,    37,   180,
           23,   358,   281,   268,   124,   251,   125,   120,   130,    92,
          258,    66,    96,   154,   151,    55,    78,    85,   144,   427,
           64,   192,   120,   206,   146,    39,   343,    15,   196,    58,
          254,   244,    29,   232,    11,    29,    65,    20, 