In [1]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from MinecraftSequencePredict import MinecraftSequencePredict

In [2]:
# -----------------------------
#      Hyperparameters
# -----------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = 250
d_model = 768
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6

src_shape  = (5,5,5)
tgt_shape  = (1,5,5)
tgt_offset = (5,0,0)
# Model Definition
model = MinecraftSequencePredict(vocab_size, 
                                 d_model, 
                                 nhead, 
                                 num_encoder_layers, 
                                 num_decoder_layers, 
                                 src_shape, 
                                 tgt_shape, 
                                 tgt_offset, 
                                 device)

In [3]:
# Training Parameters
num_epochs = 50
batch_size = 32

learning_rate = 5e-4
eps = 1e-9

dropout = 0.1

In [4]:
#create some sample data for a test forward pass
torch.manual_seed(42)

# batch_sizex5x5x5 tensor with random integers from 0 to vocab size
src_example = torch.randint(0, vocab_size - 1, size=src_shape, dtype=torch.int).view(-1).unsqueeze(0).expand(batch_size,-1).to(device)
tgt_example = torch.cat([torch.tensor([vocab_size-1]).view(-1),torch.randint(0, vocab_size - 1, size=tgt_shape, dtype=torch.int).view(-1)]).unsqueeze(0).expand(batch_size,-1).to(device)
tgt_sos = torch.tensor([vocab_size-1]).unsqueeze(0).expand(batch_size,-1).to(device)

output = model(src_example, tgt_sos)

In [5]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, eps=eps)
loss_fn = nn.CrossEntropyLoss()

batch_losses = []
epoch_losses = []

# this is just an example until a Dataloader is implemented
num_batches = 10 
for epoch in range(num_epochs):
    model.train()
    batch_losses = []

    for i in range(num_batches):
        training_src_batch = src_example.to(device) # Batch Size x Src Sequence Length
        training_tgt_batch = tgt_example.to(device) # Batch Size x Tgt Sequence Length
        
        output = model(training_src_batch, training_tgt_batch) # Batch Size x Tgt Sequence Length x Vocab Size
        
        # Batch Size x Tgt Sequence Length x Vocab Size --> (Batch Size * Tgt Sequence Length x Vocab Size)
        loss = loss_fn(output.view(-1, vocab_size), training_tgt_batch.view(-1))
        batch_losses.append(loss)
        
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        print(f'batch loss: {batch_losses[len(batch_losses)-1]}')
    
    epoch_losses.append(sum(batch_losses))
    
    print(f'epoch {epoch + 1} loss: {epoch_losses[len(epoch_losses)-1]}')

batch loss: 5.521764755249023
batch loss: 5.510337829589844
batch loss: 5.496887683868408
batch loss: 5.488706588745117
batch loss: 5.484743595123291
batch loss: 5.484691143035889
batch loss: 5.483263969421387
batch loss: 5.461836814880371
batch loss: 5.453036308288574
batch loss: 5.451832294464111
epoch 1 loss: 54.837100982666016
batch loss: 5.448557376861572
batch loss: 5.445276737213135
batch loss: 5.4478936195373535
batch loss: 5.44763708114624
batch loss: 5.437561988830566
batch loss: 5.433464527130127
batch loss: 5.398129940032959
batch loss: 5.397557258605957
batch loss: 5.355498313903809
batch loss: 5.347979545593262
epoch 2 loss: 54.1595573425293
batch loss: 5.335546493530273
batch loss: 5.319689750671387
batch loss: 5.311722755432129
batch loss: 5.287668228149414
batch loss: 5.26289701461792
batch loss: 5.235231876373291
batch loss: 5.1971259117126465
batch loss: 5.134206771850586
batch loss: 5.086359977722168
batch loss: 5.053767681121826
epoch 3 loss: 52.224220275878906
bat

In [7]:
model.eval()
validation = model(src_example.to(device), tgt_sos.to(device))
model.parameters()

<generator object Module.parameters at 0x000002178B25DD90>