In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import os
from MinecraftSequencePredict import MinecraftSequencePredict
import Dataset
from Dataset import MinecraftBlockData, custom_collate
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR

In [2]:
# -----------------------------
#      Hyperparameters
# -----------------------------

# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Transformer sizing passed positionally to the model constructor below.
vocab_size = 251
d_model = 768
nhead = 8
num_encoder_layers = num_decoder_layers = 6
dim_feed_forward = 2 * 2048

# Source/target volume geometry handed to the model
# (presumably block-grid extents and the target's offset -- confirm against MinecraftSequencePredict).
src_shape = (5, 5, 5)
tgt_shape = (1, 5, 5)
tgt_offset = (5, 0, 0)

# Model Definition
# Arguments stay positional and in the original order: the constructor's
# parameter names are not visible from this notebook.
model = MinecraftSequencePredict(
    vocab_size,
    d_model,
    nhead,
    num_encoder_layers,
    num_decoder_layers,
    dim_feed_forward,
    src_shape,
    tgt_shape,
    tgt_offset,
    device,
)
model = model.to(device)

In [3]:
# ------------------------
# Training Parameters
# ------------------------

# Number of passes over the training set, and the batch size / shuffle flag
# that both DataLoaders receive.
num_epochs, batch_size, shuffle = 40, 64, True

# Adam's starting learning rate.
learning_rate = 1e-5

# StepLR settings: the learning rate is multiplied by `gamma` every `step_size` epochs.
step_size, gamma = 1, 0.1

# Numerical-stability term passed to Adam.
eps = 1e-9

# Not referenced by any of the cells shown in this notebook.
dropout = 0.1

# Worker processes used by each DataLoader.
num_workers = 4

# Stored as a float; the dataset cell converts it with int() before use.
n_data = 1e3

In [4]:
# ----------------------
# Dataset and DataLoader
# ----------------------

path = '../Datasets/Complete_Datasets/Minecraft6_5_5/data/'

# n_data is a float (1e3), hence the int() cast.
# Presumably the second argument caps how many filenames are returned -- confirm in Dataset.get_filenames.
files = Dataset.get_filenames(path, int(n_data))

# Hold out 10% of the files; the fixed random_state keeps the split identical between runs.
train_filenames, test_filenames = train_test_split(
    files,
    test_size=0.1,
    random_state=42,
)

# Build the train dataset first, then the test dataset, from the same directory.
train_data, test_data = (
    MinecraftBlockData(path, names)
    for names in (train_filenames, test_filenames)
)

# Both loaders share batch size, shuffle flag, worker count and collate function.
training_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    collate_fn=custom_collate,
)

test_dataloader = DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    collate_fn=custom_collate,
)

In [5]:
# Adam optimizer, cross-entropy over the vocabulary, and a StepLR schedule that
# multiplies the learning rate by `gamma` every `step_size` epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, eps=eps)
loss_fn = nn.CrossEntropyLoss()
scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

# Per-epoch history: each entry is the sum of the batch losses for that epoch,
# stored as a plain Python float.
epoch_losses = []
validation_losses = []

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    validation_loss = 0.0  # stays 0 while the validation pass below is disabled

    for data in training_dataloader:
        src = data['src'].to(device)
        tgt = data['tgt'].to(device)

        # NOTE(review): the same `tgt` tensor is both the decoder input and the label;
        # confirm that the model shifts/masks the target internally.
        output = model(src, tgt)  # Batch Size x Tgt Sequence Length x Vocab Size

        # Batch Size x Tgt Sequence Length x Vocab Size --> (Batch Size * Tgt Sequence Length x Vocab Size)
        loss = loss_fn(output.view(-1, vocab_size), tgt.view(-1))

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # Accumulate a detached Python float. Adding the loss tensor itself would keep
        # every batch's autograd graph alive until the end of the epoch.
        epoch_loss += loss.item()

    # Validation pass (currently disabled). When re-enabling, switch to model.eval()
    # and wrap the loop in torch.no_grad() so no autograd graph is built.
    # for data in test_dataloader:
    #     src = data['src'].to(device)
    #     tgt = data['tgt'].to(device)
    #
    #     output = model(src, tgt) # Batch Size x Tgt Sequence Length x Vocab Size
    #
    #     loss = loss_fn(output.view(-1, vocab_size), tgt.view(-1))
    #     validation_loss += loss.item()

    epoch_losses.append(epoch_loss)
    validation_losses.append(validation_loss)

    print(f'\nEpoch {epoch + 1}, Learning Rate: {scheduler.get_last_lr()[0]}')
    print(f'Epoch {epoch + 1} Training loss: {epoch_loss:.2f}')
    # print(f'Epoch {epoch + 1} Validation loss: {validation_loss:.2f}\n')
    if epoch > 0:
        # Positive value means the summed training loss went down since the previous epoch.
        print(f"Improvements: Train: {(epoch_losses[epoch-1] - epoch_losses[epoch]):.2f}|")
        #  Test: {(validation_losses[epoch-1] - validation_losses[epoch]):.2f}

    # Advance the learning-rate schedule once per epoch, after the optimizer updates.
    scheduler.step()


Epoch 1, Learning Rate: 1e-05
Epoch 1 Training loss: 76.27

Epoch 2, Learning Rate: 5e-06
Epoch 2 Training loss: 74.09
Improvements: Train: 2.18|

Epoch 3, Learning Rate: 2.5e-06
Epoch 3 Training loss: 74.27
Improvements: Train: -0.18|

Epoch 4, Learning Rate: 1.25e-06
Epoch 4 Training loss: 74.27
Improvements: Train: -0.01|


KeyboardInterrupt: 

In [None]:
# Single inference pass on the synthetic batch, giving the decoder only `tgt_sos`
# (presumably the start-of-sequence token -- it holds index vocab_size - 1).
# NOTE: `src_example` and `tgt_sos` are created in the sample-data cell that appears
# after this one, so that cell has to be executed first on a fresh kernel.
model.eval()
# No gradients are needed for inference; no_grad avoids building an autograd graph.
with torch.no_grad():
    validation = model(src_example.to(device), tgt_sos.to(device))
# Returns a generator over the model's parameters (displayed as the cell output).
model.parameters()

<generator object Module.parameters at 0x000002178B25DD90>

In [None]:
# Build a synthetic batch for a test forward pass.
torch.manual_seed(42)

# One random volume of shape src_shape with ids in [0, vocab_size - 2] (randint's upper
# bound is exclusive), flattened to 1-D and broadcast to batch_size identical rows.
src_example = (
    torch.randint(low=0, high=vocab_size - 1, size=src_shape, dtype=torch.int)
    .flatten()
    .unsqueeze(0)
    .expand(batch_size, -1)
    .to(device)
)

# Target row: the id vocab_size - 1 followed by a flattened random tgt_shape volume,
# likewise broadcast across the batch.
tgt_example = (
    torch.cat([
        torch.tensor([vocab_size - 1]),
        torch.randint(low=0, high=vocab_size - 1, size=tgt_shape, dtype=torch.int).flatten(),
    ])
    .unsqueeze(0)
    .expand(batch_size, -1)
    .to(device)
)

# Decoder input holding only the id vocab_size - 1, one per batch row (batch_size x 1).
tgt_sos = torch.tensor([[vocab_size - 1]]).expand(batch_size, -1).to(device)

output = model(src_example, tgt_sos)