In [7]:
import torch
import torch.nn as nn

# Run on the first CUDA GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

class ChordLSTM(nn.Module):
    """Token-level sequence model: embedding -> stacked LSTM -> linear head.

    The head produces one score per vocabulary entry at every sequence
    position, so the scores at the last position can be used to pick the
    next token (see ``infer``).

    Args:
        vocab_size: Number of distinct token ids; sizes both the embedding
            table and the output layer.
        embedding_dim: Width of each token embedding.
        hidden_dim: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    # Upper bound used by ``infer`` both for the requested number of steps
    # and for the running sequence-length check.
    MAX_LENGTH = 2048

    def __init__(self, vocab_size=140, embedding_dim=2048, hidden_dim=512, num_layers=5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Return per-position vocabulary logits for a batch of token ids.

        Args:
            x: Integer token ids, laid out batch-first as (batch, seq_len).

        Returns:
            Tensor of logits with shape (batch, seq_len, vocab_size).
        """
        embedded = self.embedding(x)
        # The final (hidden, cell) state is not needed here; only the
        # per-step outputs feed the classifier head.
        output, _ = self.lstm(embedded)
        return self.fc(output)

    def infer(self, input_ids, length=2048, train=False):
        """Greedily extend ``input_ids`` one token at a time.

        At each step the whole sequence so far is fed through the network
        and the arg-max token at the last position is appended. Generation
        runs under ``torch.no_grad()`` and uses this instance's own weights.

        Args:
            input_ids: Prompt token ids, shape (seq_len,) or (batch, seq_len).
                A 1-D prompt is treated as a batch of one.
            length: Maximum number of tokens to append. Values above
                ``MAX_LENGTH`` are clamped and a notice is printed.
            train: Unused. Accepted only so existing callers keep working.

        Returns:
            Tensor of shape (batch, seq_len + n_generated) holding the prompt
            followed by the generated tokens. When ``length <= 0`` the
            (batched) prompt is returned unchanged.

        Raises:
            ValueError: If ``input_ids`` is neither 1-D nor 2-D.
        """
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if input_ids.dim() != 2:
            raise ValueError(
                f"input_ids must be 1-D or 2-D, got shape {tuple(input_ids.shape)}"
            )

        if length > self.MAX_LENGTH:
            print(f"Max Length is {self.MAX_LENGTH}. Change Length Auto to {self.MAX_LENGTH}")
            length = self.MAX_LENGTH

        # Start from the prompt so a zero-step request still returns a tensor.
        output_ids = input_ids

        with torch.no_grad():
            for _ in range(length):
                # Call this instance, not a module-level global, so the method
                # works for any ChordLSTM object.
                logits = self(output_ids)

                # Greedy choice: highest-scoring token at the last position.
                predict = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                output_ids = torch.cat((output_ids, predict), dim=-1)

                # Stop once every sequence in the batch has just produced id 0
                # (presumably the end/padding token -- confirm against the
                # tokenizer). The 0 itself is kept in the output.
                if torch.all(predict.eq(0)):
                    break

                # Stop as soon as the sequence grows past MAX_LENGTH tokens.
                if output_ids.shape[1] > self.MAX_LENGTH:
                    break

        return output_ids
    
# Build the network with the hyper-parameters used here, then place it on the
# selected device (nn.Module.to moves the parameters in place).
model = ChordLSTM(
    vocab_size=140,
    embedding_dim=512,
    hidden_dim=512,
    num_layers=5,
)
model.to(device)

# Restore the saved weights, remapping tensors onto the active device.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(
    torch.load('./model_25_0.002685.pt', map_location=device)
)

# Switch to evaluation mode; as the cell's last expression this also displays
# the module summary.
model.eval()

cuda:0


ChordLSTM(
  (embedding): Embedding(140, 512)
  (lstm): LSTM(512, 512, num_layers=5, batch_first=True)
  (fc): Linear(in_features=512, out_features=140, bias=True)
)

In [11]:
import torch

# Load the saved tensor data from disk.
# NOTE(review): torch.load unpickles the file -- only load trusted data.
loaded_data = torch.load('/workspace/data/tensor/chord_tensor.pt')

# Pick entry 3, move it to the model's device and cast to int64 so it can be
# used as embedding indices (presumably one sequence of token ids -- confirm).
test = loaded_data[3].to(device).long()
# First 30 ids: the reference to compare the generated continuation against.
print(test[:30])
# First 20 ids: the prompt that is actually fed to the model.
print(test[:20])

# Greedily extend the 20-id prompt by up to 10 ids (20 + 10 = 30 in total).
out = model.infer(test[:20], length=10)
print(out)
print(out.shape)


tensor([135,  28,  28,  83,  84,  37,  37, 114, 114,  29,  29,  83,  84,  37,
         37, 114, 114,  29,  59,  83,  83,  37,  37, 114, 114,  29,  29,  83,
         83,  37], device='cuda:0')
tensor([135,  28,  28,  83,  84,  37,  37, 114, 114,  29,  29,  83,  84,  37,
         37, 114, 114,  29,  59,  83], device='cuda:0')
tensor([[135,  28,  28,  83,  84,  37,  37, 114, 114,  29,  29,  83,  84,  37,
          37, 114, 114,  29,  59,  83,  83,  37,  37, 114, 114,  29,  59,  83,
          83,  37]], device='cuda:0')
torch.Size([1, 30])
