# 모델 & 데이터 로딩

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2LMHeadModel
'argon2:$argon2id$v=19$m=10240,t=10,p=8$gaqXFKhxwobwEqSGu91mOw$5JAzPFFjzbGA/HfxZ2JpAlLlBdh+uD4tIL1FbiCRFdE'
class GPT2Model(nn.Module):
    def __init__(self, vocab_size=140, n_embd=768, n_layer=12, n_head=12):
        super(GPT2Model, self).__init__()
        self.configuration = GPT2Config(vocab_size=vocab_size, n_embd=n_embd, n_layer=n_layer, n_head=n_head, bos_token_id=2, eos_token_id=1)
        self.model = GPT2LMHeadModel(self.configuration)
        
    def get_embed(self, idx):
        embedding_layer = self.model.transformer.wte
        token_embedding = embedding_layer(torch.tensor([idx]))
        return token_embedding
    
    def extract_vocab_embeddings(self):
        # Extract all the embeddings for the entire vocabulary
        embedding_layer = self.model.transformer.wte
        vocab_embeddings = embedding_layer.weight.detach().clone()
        return vocab_embeddings

    def forward(self, input_ids, labels=None, return_hidden_states=False):
        attention_mask = self.make_mask(input_ids)
        # Forward pass through the transformer to get hidden states
        transformer_outputs = self.model.transformer(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)

        # Extract hidden states before the projection
        hidden_states = transformer_outputs.last_hidden_state
        
        if return_hidden_states:
            return hidden_states

        # Project the hidden states to vocabulary size
        logits = self.model.lm_head(hidden_states)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.configuration.vocab_size), labels.view(-1))
            return loss, logits
        return logits

    def make_mask(self, input_ids):
        attention_mask = (input_ids != 0).long()
        return attention_mask
    
    def infer(self, input_ids, length=2048):
        if len(input_ids.shape) == 1:
            input_ids = input_ids.unsqueeze(0)
        if len(input_ids.shape) > 2:
            raise Exception
        
        if length > 2048:
            print("Max Length is 2048. Change Length Auto to 2048")
            length = 2048
        
        with torch.no_grad():
            for step in range(length):
                logits = self.forward(input_ids)
                output = torch.argmax(logits, dim=2)

                predict = output[:,-1].unsqueeze(1)
                output_ids = torch.cat((input_ids, predict), dim=-1)

                input_ids = output_ids
                
                if output_ids.shape[1] > 2048:
                    break

        return output_ids

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.expand = nn.Linear(input_dim, hidden_dim*2)
        self.dropex = nn.Dropout(0.15)
        
        self.fc1 = nn.Linear(hidden_dim*2, hidden_dim//2)
        self.drop1 = nn.Dropout(0.15)
        self.fc2 = nn.Linear(hidden_dim//2, hidden_dim//4)
        self.drop2 = nn.Dropout(0.15)
        self.fc3 = nn.Linear(hidden_dim//4, hidden_dim//2)
        self.drop3 = nn.Dropout(0.15)
        self.fc4 = nn.Linear(hidden_dim//2, hidden_dim)

        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
    
    def forward(self, x):
        # x: [bsz, 133] input binary vector
        x = x.float()
        h = F.relu(self.expand(x))
        h = self.dropex(h)
        
        h = F.relu(self.fc1(h))  # [bsz, hidden_dim]
        h = self.drop1(h)
        h = F.relu(self.fc2(h))
        h = self.drop2(h)
        h = F.relu(self.fc3(h))
        h = self.drop3(h)
        h = F.relu(self.fc4(h))

        mu = self.fc_mu(h)       # [bsz, latent_dim]
        logvar = self.fc_logvar(h)  # [bsz, latent_dim]
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, hidden_dim)
        self.rnn = nn.GRU(hidden_dim, output_dim, num_layers=3, batch_first=True)
    
    def forward(self, z, seq_lens, chord_embedding):
        # z: [bsz, latent_dim] sampled latent vector
        h = F.relu(self.fc(z)).unsqueeze(1)  # [bsz, 1, hidden_dim]
        
        # Prepare initial hidden state for GRU
        max_len = max(seq_lens)  # Maximum sequence length in the batch
        h_repeat = h.expand(-1, max_len, -1)  # [bsz, max_len, hidden_dim]
        chord_embedding = chord_embedding/100
        h_repeat = h_repeat + chord_embedding
        # Initialize an empty tensor to hold output sequences with different lengths
        outputs = []
        for i, seq_len in enumerate(seq_lens):
            # Use GRU to generate sequences of length `seq_len`
            out, _ = self.rnn(h_repeat[i:i+1, :seq_len])  # [1, seq_len, 133]
            outputs.append(out.squeeze(0))  # [seq_len, 133]
        
        return outputs  # List of tensors with shapes [seq_len_i, 133]

class VAE(nn.Module):
    def __init__(self, input_dim, encoder_hdim, decoder_hdim, latent_dim, output_dim, device):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, encoder_hdim, latent_dim)
        self.decoder = Decoder(latent_dim, decoder_hdim, output_dim)
        
        self.device = device
        self.chord_encoder = GPT2Model(vocab_size=150)
        self.chord_encoder.load_state_dict(torch.load('/workspace/out/chord_bpe/GPT2_BPE_V150/model_207_0.4520_0.3645.pt'))
        # Freeze the chord_transformer parameters
        for param in self.chord_encoder.parameters():
            param.requires_grad = False
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x, chord_tensor, seq_lens):
        # x: [bsz, 133], seq_lens: list of output sequence lengths
        mu, logvar = self.encoder(x)  # Encode to get mu and logvar
        
        z = self.reparameterize(mu, logvar)  # Sample from the latent space
        
        chord_embedding = self.chord_encoder(chord_tensor, return_hidden_states=True)
        chord_embedding = chord_embedding[:,1:-1,:]
        output = self.decoder(z, seq_lens, chord_embedding)   # Decode to variable-length sequences
        return output, mu, logvar

# 그냥 버전에서는 인코더에 시작악기 벡터 넣을때 dim으로 보내지 말고 그냥 히든 원핫으로 넣어보는것도 생각

# 인코더 그냥 간단하게 리니어하게 z보낸 다음에 디코더로 복원할때 레이ㅓㄴ트 벡터 늘려서 rnn말고 "길이" 로 늘려서 대신에 리니어로 보낸 같은 레이턴트 벡터 길이만큼 있을때 각 벡터들 에 포지션 인코딩 느낌으로 넣는거 ㄱㅊ을듯ㄴ

In [3]:
import torch
import json
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split

class InstDataset(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
        inst_vocab_path = '/workspace/pj/data/vocabs/inst.json'
        chord_vocab_path = '/workspace/pj/data/vocabs/chord.json'
        with open(inst_vocab_path, 'r') as file:
            self.inst_vocab = json.load(file)
        with open(chord_vocab_path, 'r') as file:
            self.chord_vocab = json.load(file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text_seq = self.data[idx]
        
        if isinstance(text_seq, str):
            toks = text_seq.split()
            
        l_toks = len(toks)
        ratio = 4
        chord_list = []
        inst_in_measure = []
        inst_list = []
        
        for idx in range(0, l_toks, ratio):
            t1, t2, t3, t4 = toks[idx : idx + 4]
            if t1[0] == 'H':
                chord_list.append(t1)

            if t4[0] == 'x' or t4[0] == 'X' or t4[0] == 'y' or t4 == '<unk>':
                inst_in_measure.append(t4)
                
            if (t1[0] == 'm' or t1[0] == 'M') and len(chord_list) > 0:
                inst_list.append(inst_in_measure)
                inst_in_measure = []
        inst_list.append(inst_in_measure)
        
        chord_tensor = [self.chord_vocab[chd] for chd in chord_list]
        inst_tensor, length = self.convert_inst_to_onehot(inst_list)
        
        target_chord_tensor = [2] + chord_tensor[:766] + [1]
        target_chord_tensor = torch.tensor(target_chord_tensor)
        
        target_inst_tensor = inst_tensor

        return target_chord_tensor, target_inst_tensor, length+2
    
    def convert_inst_to_onehot(self, inst_list):
        base_tensor = torch.zeros(len(inst_list), 133)
        bos_tensor = torch.zeros(1, 133)
        eos_tensor = torch.zeros(1, 133)
        bos_tensor[:,2] = 1
        eos_tensor[:,1] = 1
        
        for idx, inst_in_measure in enumerate(inst_list):
            if len(inst_in_measure) == 0:
                continue
            else:
                for inst in inst_in_measure:
                    base_tensor[idx, self.inst_vocab[inst]] = 1
        inst_tensor = torch.cat((bos_tensor,base_tensor[:766,:],eos_tensor), dim=0)
        return inst_tensor, len(inst_list)
  
class InstGroupDataset(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
        inst_vocab_path = '/workspace/data/vocabs/inst_group60_vocab.json'
        chord_vocab_path = '/workspace/data/vocabs/chord.json'
        with open(inst_vocab_path, 'r') as file:
            self.inst_vocab = json.load(file)
        with open(chord_vocab_path, 'r') as file:
            self.chord_vocab = json.load(file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text_seq = self.data[idx]
        
        if isinstance(text_seq, str):
            toks = text_seq.split()
            
        l_toks = len(toks)
        ratio = 4
        chord_list = []
        inst_in_measure = []
        inst_list = []
        
        for idx in range(0, l_toks, ratio):
            t1, t2, t3, t4 = toks[idx : idx + 4]
            if t1[0] == 'H':
                chord_list.append(t1)

            if t4[0] == 'x' or t4[0] == 'X' or t4[0] == 'y' or t4 == '<unk>':
                inst_in_measure.append(t4)
                
            if (t1[0] == 'm' or t1[0] == 'M') and len(chord_list) > 0:
                inst_list.append(inst_in_measure)
                inst_in_measure = []
        inst_list.append(inst_in_measure)
        
        chord_tensor = [self.chord_vocab[chd] for chd in chord_list]
        inst_tensor, length = self.convert_inst_to_onehot(inst_list)
        
        target_chord_tensor = [2] + chord_tensor[:510] + [1]
        target_chord_tensor = torch.tensor(target_chord_tensor)
        
        target_inst_tensor = [2] + inst_tensor[:510] + [1]
        length = len(target_inst_tensor)
        target_inst_tensor = torch.tensor(target_inst_tensor)

        return target_chord_tensor, target_inst_tensor, torch.tensor([length])
    
    def convert_inst_to_onehot(self, inst_list):
        inst_vocab = []
        
        for insts in inst_list:
            group_inst = ''
            for inst in insts:
                if inst in group_inst:
                    pass
                else:
                    group_inst += inst
            
            try:
                inst_vocab.append(self.inst_vocab[group_inst])
            except:
                inst_vocab.append(3)
                
        return inst_vocab, len(inst_list)
    
class C2IDataset(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
        inst_vocab_path = '/workspace/data/vocabs/inst.json'
        chord_vocab_path = '/workspace/data/vocabs/chord.json'
        with open(inst_vocab_path, 'r') as file:
            self.inst_vocab = json.load(file)
        with open(chord_vocab_path, 'r') as file:
            self.chord_vocab = json.load(file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text_seq = self.data[idx]
        
        if isinstance(text_seq, str):
            toks = text_seq.split()
            
        l_toks = len(toks)
        ratio = 4
        chord_list = []
        inst_in_measure = []
        inst_list = []
        
        inst_tensor = torch.zeros(133)
        
        for idx in range(0, l_toks, ratio):
            t1, t2, t3, t4 = toks[idx : idx + 4]
            if t1[0] == 'H':
                chord_list.append(t1)

            if t4[0] == 'x' or t4[0] == 'X' or t4[0] == 'y' or t4 == '<unk>':
                inst_tensor[self.inst_vocab[t4]] = 1
        
        chord_tensor = [self.chord_vocab[chd] for chd in chord_list]
        
        target_chord_tensor = [2] + chord_tensor[:766] + [1]
        target_chord_tensor = torch.tensor(target_chord_tensor)
        
        target_inst_tensor = inst_tensor

        return target_chord_tensor, target_inst_tensor, torch.tensor([target_chord_tensor.shape[0]])
    
class InstGRUDataset(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
        inst_vocab_path = '/workspace/data/vocabs/inst.json'
        chord_vocab_path = '/workspace/data/vocabs/chord.json'
        with open(inst_vocab_path, 'r') as file:
            self.inst_vocab = json.load(file)
        with open(chord_vocab_path, 'r') as file:
            self.chord_vocab = json.load(file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text_seq = self.data[idx]
        
        if isinstance(text_seq, str):
            toks = text_seq.split()
            
        l_toks = len(toks)
        ratio = 4
        chord_list = []
        ans_inst_container = []
        inst_in_measure = []
        inst_list = []
        inst_tensor = torch.zeros(133)
        
        for idx in range(0, l_toks, ratio):
            t1, t2, t3, t4 = toks[idx : idx + 4]
            if t1[0] == 'h' or t1[0] == 'H':
                chord_list.append(t1)

            if t4[0] == 'x' or t4[0] == 'X' or t4[0] == 'y' or t4 == '<unk>':
                inst_tensor[self.inst_vocab[t4]] = 1
                
            if t4[0] == 'x' or t4[0] == 'X' or t4[0] == 'y' or t4 == '<unk>':
                inst_in_measure.append(t4)
                
            if (t1[0] == 'm' or t1[0] == 'M') and len(chord_list) > 0:
                inst_list.append(inst_in_measure)
                inst_in_measure = []
        
        inst_list.append(inst_in_measure)

        chord_tensor = [self.chord_vocab[chd] for chd in chord_list]
        ans_inst_container = self.convert_inst_to_onehot(inst_list, ans_inst_container)
        
        target_chord_tensor = [2] + chord_tensor[:766] + [1]
        target_chord_tensor = torch.tensor(target_chord_tensor)
        
        init_inst_tensor = inst_tensor

        return target_chord_tensor, init_inst_tensor, ans_inst_container
    
    def convert_inst_to_onehot(self, inst_list, ans_inst_container):
        
        for _ in range(133):
            ans_inst_container.append([0]*len(inst_list))
        
        for idx, inst_in_measure in enumerate(inst_list):
            if len(inst_in_measure) == 0:
                continue
            else:
                for inst in inst_in_measure:
                    # base_tensor[idx, self.inst_vocab[inst]] = 1
                    ans_inst_container[self.inst_vocab[inst]][idx] = 1
                    
        for idx, vec in enumerate(ans_inst_container):
            ans_inst_container[idx] = torch.tensor(vec[:766])
        
        return ans_inst_container
   
class InstVAEDataset(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
        inst_vocab_path = '/workspace/data/vocabs/inst.json'
        chord_vocab_path = '/workspace/data/vocabs/chord.json'
        with open(inst_vocab_path, 'r') as file:
            self.inst_vocab = json.load(file)
        with open(chord_vocab_path, 'r') as file:
            self.chord_vocab = json.load(file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text_seq = self.data[idx]
        
        if isinstance(text_seq, str):
            toks = text_seq.split()
            
        l_toks = len(toks)
        ratio = 4
        chord_list = []
        init_instruments = [0]*133
        inst_in_measure = []
        inst_list = []
        
        for idx in range(0, l_toks, ratio):
            t1, t2, t3, t4 = toks[idx : idx + 4]
            if t1[0] == 'H':
                chord_list.append(t1)

            if t4[0] == 'x' or t4[0] == 'X' or t4[0] == 'y' or t4 == '<unk>':
                inst_in_measure.append(t4)
                
                init_instruments[self.inst_vocab[t4]] = 1
                
            if (t1[0] == 'm' or t1[0] == 'M') and len(chord_list) > 0:
                inst_list.append(inst_in_measure)
                inst_in_measure = []
        inst_list.append(inst_in_measure)
        
        chord_tensor = [self.chord_vocab[chd] for chd in chord_list]
        inst_tensor, length = self.convert_inst_to_onehot(inst_list)
        
        target_chord_tensor = [2] + chord_tensor[:766] + [1]
        target_chord_tensor = torch.tensor(target_chord_tensor)
        
        target_init_inst_tensor = torch.tensor(init_instruments)
        
        target_inst_tensor = inst_tensor

        return target_chord_tensor, target_init_inst_tensor, target_inst_tensor, length
    
    def convert_inst_to_onehot(self, inst_list):
        base_tensor = torch.zeros(len(inst_list), 133)
        
        for idx, inst_in_measure in enumerate(inst_list):
            if len(inst_in_measure) == 0:
                continue
            else:
                for inst in inst_in_measure:
                    base_tensor[idx, self.inst_vocab[inst]] = 1
        inst_tensor = base_tensor[:766,:]
        return inst_tensor, len(inst_tensor)
    
def create_dataloaders(batch_size):
    raw_data_path = '../../../workspace/data/corpus/raw_corpus_bpe.txt'
    # raw_data_path = '../../../workspace/data/corpus/first_5_lines_bpe.txt'
    raw_data = []
    with open(raw_data_path, 'r') as f:
        for line in tqdm(f, desc="reading original txt file..."):
            raw_data.append(line.strip())
            
    train, val_test = train_test_split(raw_data, test_size=0.1, random_state=5)
    val, test = train_test_split(val_test, test_size=0.2, random_state=5)
    # train, val_test = train_test_split(raw_data, test_size=0.5, random_state=5)
    # val, test = train_test_split(val_test, test_size=0.2)
    
    train_dataset = InstDataset(train)
    val_dataset = InstDataset(val)
    test_dataset = InstDataset(test)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_batch)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_batch)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_batch)

    # return train_loader, True, True
    return train_loader, val_loader, test_loader

def collate_batch(batch):
    chords, insts, length = zip(*batch)
    # padding_value = <eos>
    chord_padded = pad_sequence(chords, padding_value=0, batch_first=True)
    inst_padded = pad_sequence(insts, padding_value=0, batch_first=True)
    length_padded = pad_sequence(length, padding_value=0, batch_first=True)
    return chord_padded, inst_padded, length_padded

def create_C2I(batch_size):
    raw_data_path = '../../../workspace/data/corpus/raw_corpus_bpe.txt'
    # raw_data_path = '../../../workspace/data/corpus/first_5_lines_bpe.txt'
    raw_data = []
    with open(raw_data_path, 'r') as f:
        for line in tqdm(f, desc="reading original txt file..."):
            raw_data.append(line.strip())
            
    train, val_test = train_test_split(raw_data, test_size=0.1, random_state=5)
    val, test = train_test_split(val_test, test_size=0.2, random_state=5)
    
    train_dataset = C2IDataset(train)
    val_dataset = C2IDataset(val)
    test_dataset = C2IDataset(test)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_batch_C2I)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_batch_C2I)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_batch_C2I)

    # return train_loader, True, True
    return train_loader, val_loader, test_loader

def collate_batch_C2I(batch):
    chords, insts, length = zip(*batch)
    
    # padding_value = <eos>
    chord_padded = pad_sequence(chords, padding_value=0, batch_first=True)
    inst_padded = pad_sequence(insts, padding_value=0, batch_first=True)
    length_padded = pad_sequence(length, padding_value=0, batch_first=True)

    return chord_padded, inst_padded, length_padded


def create_Group(batch_size):
    raw_data_path = '../../../workspace/data/corpus/raw_corpus_bpe.txt'
    # raw_data_path = '../../../workspace/data/corpus/first_5_lines_bpe.txt'
    raw_data = []
    with open(raw_data_path, 'r') as f:
        for line in tqdm(f, desc="reading original txt file..."):
            raw_data.append(line.strip())
            
    train, val_test = train_test_split(raw_data, test_size=0.1, random_state=5)
    val, test = train_test_split(val_test, test_size=0.2, random_state=5)
    
    train_dataset = InstGroupDataset(train)
    val_dataset = InstGroupDataset(val)
    test_dataset = InstGroupDataset(test)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_batch_Group)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_batch_Group)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_batch_Group)

    # return train_loader, True, True
    return train_loader, val_loader, test_loader

def collate_batch_Group(batch):
    chords, insts, length = zip(*batch)
    
    # padding_value = <eos>
    chord_padded = pad_sequence(chords, padding_value=0, batch_first=True)
    inst_padded = pad_sequence(insts, padding_value=0, batch_first=True)
    length_padded = pad_sequence(length, padding_value=0, batch_first=True)

    return chord_padded, inst_padded, length_padded

def create_InstGRU(batch_size):
    raw_data_path = '../../../workspace/data/corpus/raw_corpus_bpe.txt'
    # raw_data_path = '../../../workspace/data/corpus/first_5_lines_bpe.txt'
    raw_data = []
    with open(raw_data_path, 'r') as f:
        for line in tqdm(f, desc="reading original txt file..."):
            raw_data.append(line.strip())
            
    train, val_test = train_test_split(raw_data, test_size=0.1, random_state=5)
    val, test = train_test_split(val_test, test_size=0.2, random_state=5)
    
    train_dataset = InstGRUDataset(train)
    val_dataset = InstGRUDataset(val)
    test_dataset = InstGRUDataset(test)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_batch_InstGRU)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_batch_InstGRU)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_batch_InstGRU)

    # return train_loader, True, True
    return train_loader, val_loader, test_loader

def collate_batch_InstGRU(batch):
    target_chord_tensor, init_inst_tensor, ans_inst_container = zip(*batch)
    # padding_value = <eos>
    chord_padded = pad_sequence(target_chord_tensor, padding_value=0, batch_first=True)
    inst_padded = pad_sequence(init_inst_tensor, padding_value=0, batch_first=True)
    # ans_padded = pad_sequence(ans_inst_container, padding_value=0, batch_first=True)

    return chord_padded, inst_padded, ans_inst_container


def create_VAE(batch_size):
    raw_data_path = '../../../workspace/data/corpus/raw_corpus_bpe.txt'
    raw_data = []
    with open(raw_data_path, 'r') as f:
        for line in tqdm(f, desc="reading original txt file..."):
            raw_data.append(line.strip())
            
    train, val_test = train_test_split(raw_data, test_size=0.1, random_state=5)
    val, test = train_test_split(val_test, test_size=0.2, random_state=5)
    
    train_dataset = InstVAEDataset(train)
    val_dataset = InstVAEDataset(val)
    test_dataset = InstVAEDataset(test)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_batch_VAE)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_batch_VAE)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_batch_VAE)

    return train_loader, val_loader, test_loader

def collate_batch_VAE(batch):
    target_chord_tensor, target_init_inst_tensor, target_inst_tensor, length = zip(*batch)
    chord_padded = pad_sequence(target_chord_tensor, padding_value=0, batch_first=True)
    init_padded = pad_sequence(target_init_inst_tensor, padding_value=155, batch_first=True)
    inst_padded = pad_sequence(target_inst_tensor, padding_value=0, batch_first=True)
    # length_padded = pad_sequence(length, padding_value=0, batch_first=True)
    return chord_padded, init_padded, inst_padded, length


# train_loader, val_loader, test_loader = create_VAE(3)
# for chord, init, inst, length in train_loader:
#     print(chord)
#     print(chord.shape)
#     print(init)
#     print(init.shape)
#     print(inst)
#     print(inst.shape)
    
#     print(length)
#     break

In [4]:
def calculate_accuracy(output, target):
    output = torch.sigmoid(output)  # Apply sigmoid to get probabilities
    predictions = (output > 0.3).float()  # Convert probabilities to binary (0 or 1)
    correct = (predictions == target).float().sum()
    # accuracy = correct / target.numel()
    # return accuracy.item()
    return correct, target.numel()

def vae_loss_function(recon_x, x, mu, logvar, seq_lens):
    # recon_x: list of [seq_len_i, 133] (decoded sequences)
    # x: 
    # seq_lens: list of sequence lengths
    # mu, logvar: VAE latent distribution parameters

    # Reconstruction loss: Binary Cross-Entropy (or can use MSE)
    recon_loss = 0
    
    correct_total = 0
    cnt_total = 0
    
    for i, seq_len in enumerate(seq_lens):
        # Repeat the input across the sequence length to compare with output
        x_repeated = x[i, :seq_len, :]  # [seq_len, 133]
        # print(x_repeated.shape)
        
        ######
        # Apply sigmoid to logits to get probabilities
        probs = torch.sigmoid(recon_x[i])
        
        # Adjust target based on custom threshold (e.g., 0.3 instead of 0.5)
        thresholded_output = (probs > 0.5).float()
        ######
        # recon_loss += nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0]).to(device))(recon_x[i], x_repeated)  # Loss for this sequence
        recon_loss += nn.BCEWithLogitsLoss()(recon_x[i], x_repeated)  # Loss for this sequence
        # recon_loss += focal_loss_fn(recon_x[i], x_repeated)
        
        correct, total = calculate_accuracy(recon_x[i], x_repeated)
        correct_total += correct
        cnt_total += total

    # KL-divergence loss: encourages latent variables to follow normal distribution
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Total loss is sum of reconstruction and KL-divergence loss
    total_loss = recon_loss + kld_loss
    return total_loss, recon_loss, kld_loss, correct_total, cnt_total

# 모델 데이터 변수 설정

In [5]:
train_loader, val_loader, test_loader = create_VAE(2)

reading original txt file...: 46188it [00:11, 4038.22it/s]


In [12]:
torch.cuda.empty_cache()
base_model = 'VAE'
# Usage example:
input_dim = 133    # Input dimension
encoder_hdim = 512   # Hidden layer size
decoder_hdim = 768
latent_dim = 128    # Latent space dimension
output_dim = 133   # Output dimension (sequence element size)
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
print(device)
model = VAE(input_dim, encoder_hdim, decoder_hdim, latent_dim, output_dim, device)
name = f'{base_model}_Focal_0.25_3'
model.load_state_dict(torch.load(f'/workspace/out/inst_vae/{name}/model_14_0.8351_0.8344.pt'))
model.eval()
model.to(device)


cuda:2


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

# 샘플 확인

In [57]:
torch.set_printoptions(profile="full")
cnt = 0
for chord, init, inst, length in tqdm(train_loader, ncols=60):
    cnt += 1
    if cnt == 16:
        print(chord)
        print(init)
        # print(inst)
        print(length)
        chord = chord.to(device)
        init = init.to(device)
        target = inst.to(device)
        
        recon_x, mu, logvar = model(init, chord, length)
        loss, recon_loss, kld_loss, correct, cnt = vae_loss_function(recon_x, target, mu, logvar, length)
        print(length)
        print(correct)
        print(cnt)
        print(correct/cnt)
        break

  0%|                    | 15/20785 [00:02<50:29,  6.86it/s]

tensor([[  2,  37,   6,  37,  62,  92,  36,  36,  92,  62, 134,  58, 134,  23,
          41, 114,  37, 114,  81,  81,  83,   6,  23,  26,  23,  26,  43,  37,
         114,  37,  83,  37,  37, 114,  62, 114, 114,  23,  37,  37, 114,  37,
         114,  37,   7,  93,  62,  37,  37,   7,  58, 107, 107,  59, 116,  59,
         116,  59,  80,   6,  58, 114,  38,  58, 114,  58, 114,  58, 114,  41,
          83,  61,  37,   6, 114,  58, 114,  58, 114,  41,  83,  61,  37,   6,
         114, 114, 114, 114, 114, 114, 114, 114, 114, 114,  23, 114, 114, 133,
         114, 133, 114, 133, 114, 114,  59, 116,  59,  63,   6,   6,  81,   6,
          81,   6, 114,  37,  26,  83,   6,  58, 114,  38, 107,  25,  83,   7,
          83, 133,  80,  80,  80,  10,   3,   3,   3,  98,  58,  58,  58,   7,
         113, 113, 113,  38,  37, 114,  37,  59,  23,  37,  61,  83,  92, 114,
          37, 114,  37, 114,  37,   7,  93,  62,  37,  37,   7,  58, 107,  58,
         113, 113,  36,  36,  92,  92, 114,  39, 114




In [58]:
probs = torch.sigmoid(recon_x[0])  # probs: [batch_size, max_length, num_classes]
    
# Convert probabilities to binary predictions
preds = (probs > 0.5).float()

print(preds)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

In [59]:
init_instruments = [0]*133

for measure in preds:
    for idx, inst in enumerate(measure):
        if inst == 1:
            init_instruments[idx] = 1
print(init_instruments)
print(list(init[0].cpu().numpy()))

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [45]:
init_instruments = [0]*133

for measure in preds:
    for idx, inst in enumerate(measure):
        if inst == 1:
            init_instruments[idx] = 1
print(init_instruments)
print(list(init[1].cpu().numpy()))

[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]


In [29]:
print(inst)

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
          0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
          0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

In [None]:
pos_weight = torch.ones([64])
print(pos_weight.shape)

In [None]:
import torch

# Example tensors
tensor1 = torch.tensor([0, 1, 1, 0, 0, 1])
tensor2 = torch.tensor([0, 0, 1, 1, 0, 1])

# Calculate pairs
zero_zero = torch.sum((tensor1 == 0) & (tensor2 == 0)).item()
zero_one = torch.sum((tensor1 == 0) & (tensor2 == 1)).item()
one_zero = torch.sum((tensor1 == 1) & (tensor2 == 0)).item()
one_one = torch.sum((tensor1 == 1) & (tensor2 == 1)).item()

# Display the results
print(f"0-0: {zero_zero}")
print(f"0-1: {zero_one}")
print(f"1-0: {one_zero}")
print(f"1-1: {one_one}")


In [None]:
train_loader, val_loader, test_loader = create_dataloaders(64)


In [None]:
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
total_measure_len = 0
total_inst_sum = torch.zeros([133]).int().to(device)
torch.set_printoptions(profile="full")
for (chords, targets, lengths) in tqdm(test_loader, ncols=60):

    targets = torch.sum(targets.int().to(device), dim=(0,1)).to(device)
    total_inst_sum = targets + total_inst_sum
    total_measure_len += sum(lengths)
    # total_cnt, z_z, o_o, extra, lack = subset_accuracy(outputs, targets, lengths)
print(total_measure_len)
print(total_inst_sum)

In [None]:
for (chords, targets, lengths) in tqdm(val_loader, ncols=60):

    targets = torch.sum(targets.int().to(device), dim=(0,1)).to(device)
    total_inst_sum = targets + total_inst_sum
    total_measure_len += sum(lengths)
    # total_cnt, z_z, o_o, extra, lack = subset_accuracy(outputs, targets, lengths)
print(total_measure_len)
print(total_inst_sum)

In [None]:
import torch
insts = torch.tensor([     1,   4619,   4619,      1, 194689,   5963,   2904,   3006,   4043,
          1865,   5910,   1144,   4199,  46075,   2347,  30494,  26470,  20009,
         21011,   1575,   2001,   1447,   1673,   8121,   2218,   5168,   3859,
           589,  13026,  10995,   3762,  12184,   2909,   6010,   4936,    800,
         25684,  29864,  18695,   3947,   1035,   1270,   1680,   1725, 133476,
         83210, 112155, 110429,   6010,  35518,  39569, 123554, 209919,   5973,
          6197,   1437,  41111,   3282,   1407,    803, 206607, 194703, 184401,
          3298, 246582,   7145,   2799,   1283,  13321, 111674,  94783,  80731,
        202936,  19339, 194458, 263524,  72578, 274172,   5851,   5062,    262,
           433,   3253,   2275,   3255,   3337,   2120,    159,    176,    982,
           170,   2839,   1002,   2214,    785,   1160,    363,    239,    452,
           334,    696,    212,    259,    738,    527,    124,    301,    151,
           325,   2918,    121,    465,    474,    692,    592,     79,   1128,
            24,    720,    164,   1640,   1216,     15,    469,     20,     99,
           235,     76,     21,     96,    131,     28, 221543])

zeros = 599763 - insts
insts = insts.float()
pos = ((zeros/insts)).floor().int()
# print(zeros)
print(pos)
print(pos.shape)

In [None]:
for (chords, targets, lengths) in tqdm(train_loader, ncols=60):

    targets = torch.sum(targets.int().to(device), dim=(0,1)).to(device)
    total_inst_sum = targets + total_inst_sum
    total_measure_len += sum(lengths)
    # total_cnt, z_z, o_o, extra, lack = subset_accuracy(outputs, targets, lengths)
print(total_measure_len)
print(total_inst_sum)

In [None]:
import torch

insts = torch.tensor([      1,   46188,   46188,       1, 1965170,   52627,   19953,   27346,
          43178,   12233,   56133,    8498,   46038,  446219,   30396,  316394,
         299141,  201873,  204341,   14132,   17016,   16367,   17421,   88733,
          21916,   45745,   32097,    9057,  121791,  110299,   37992,  132749,
          27637,   47773,   45304,    5699,  251588,  315148,  180803,   28895,
          11680,    9634,   19766,   15624, 1326497,  871168, 1146786, 1073349,
          83884,  355458,  372425, 1268948, 2122342,   81723,   68123,   16416,
         407334,   26430,    9717,    7823, 2069709, 1954102, 1836015,   34275,
        2426194,   74003,   24758,   12064,  130216, 1108359,  931372,  782923,
        2007699,  192456, 1959928, 2628553,  738617, 2753066,   52856,   37749,
           4444,    4338,   26182,   21912,   35167,   34416,   17968,    2895,
           1462,    8413,    1562,   25468,   11404,   14087,    6955,   11182,
           4150,    3995,    5718,    4049,   11097,    2869,    3242,    7486,
           6618,    1661,    3252,    2992,    6783,   29266,    2395,    8733,
           4926,    7649,   10733,    1368,   12724,     401,    8195,    2366,
          13414,   11592,    1480,    4097,     260,     756,    2571,     278,
            218,     793,     471,    1569, 2227362])

zeros = 6084876 - insts
insts = insts.float()
pos = torch.ceil((zeros/insts)).int()
# print(zeros)
print(pos)
print(torch.median(pos.float()))
print(torch.mean(pos.float()))
print(pos.shape)

In [None]:
import torch

# Example tensor
float_tensor = torch.tensor([1.5, 2.3, 3.7])

# Convert to integer tensor with rounding up
int_tensor = torch.ceil(float_tensor).int()

print(int_tensor)
print(float_tensor[:-1])
