#Create the models

In [2]:
# Create the baseline model
class BiLSTMTextClassifierModel(nn.Module):

  def __init__(self, vocab, embedding_dim, hidden_dim, number_of_labels):
    super(BiLSTMTextClassifierModel, self).__init__()
    self.number_of_labels = number_of_labels
    self.embedding = nn.Embedding(len(vocab), embedding_dim, vocab["<pad>"])
    self.rnn = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
    self.top_layer = nn.Linear(2*hidden_dim, self.number_of_labels)
    self.relu = nn.ReLU()
    self.softmax = F.softmax

  def forward(self, x):
    embeddings = self.embedding(x)
    rnn_output, _ = self.rnn(embeddings)
    last_hidden = rnn_output[:, -1, :]
    top_layer_output = self.top_layer(self.relu(last_hidden))
    return self.softmax(self.relu(top_layer_output), dim=-1)



ModuleNotFoundError: No module named 'models'

In [14]:
# trains one batch, returns total batch loss
def train_one_batch(model, inputs, targets, optimizer, loss_function):
        # Predict/Forward Pass
        predictions = model(inputs)
        # Compute loss
        loss = loss_function(predictions, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Multiply the cross entropy loss which is the average by the batch size so we get the total loss for the batch, we can divide this by all data set 
        # Size to get average loss for the epoch
        batch_size_train =  len(inputs)
        batch_loss = loss.item() * batch_size_train
        return batch_loss
    

In [15]:
# validates one batch, returns total batch loss, number of true positives
def validate_one_batch(model, inputs, targets, loss_function):
        predictions_val = model(inputs).detach()
        loss_validation = loss_function(predictions_val, targets)
        
        
        # calculate average loss
        batch_size_val = len(inputs)
        batch_loss = loss_validation.item() * batch_size_val
        
        # Calculate True positives
        predicted_class = predictions_val.argmax(axis=1)
        correct_class = targets.argmax(axis=1)
    
        true_positives_count = sum(predicted_class == correct_class).item()
        return batch_loss, true_positives_count

In [16]:
# Train model
model = BiLSTMTextClassifierModel(liar_vocab, 300, 128, 6)
# Training params
num_epochs = 50
# Hyper parameters
learning_rate = 0.001

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
device = torch.device("cuda:6")
model.to(device)
loss_fn.to(device)

train_loss = []
val_loss = []
val_accuracy_history = []
model.train()
for epoch in range(num_epochs):
    start_time = time.time()
    # Training
    model.train()
    total_train_loss = 0. 
    total_data_points_train = 0
    for train_batch in train_loader:
        inputs = train_batch[0].to(device)
        targets = train_batch[1].to(device)
        
        train_batch_loss = train_one_batch(model, inputs, targets, optimizer, loss_fn)
        
        total_train_loss += train_batch_loss
        total_data_points_train += len(inputs)
    
    
    average_epoch_loss_train = total_train_loss/total_data_points_train
    train_loss.append(average_epoch_loss_train)
    
    # Validation
    # TODO restructure
    model.eval()        
    total_val_loss = 0. 
    total_data_points_val = 0
    true_positives_val = 0
    for validation_batch in validation_loader:
        inputs_val = validation_batch[0].to(device)
        targets_val = validation_batch[1].to(device)
        
        # Returns loss and true positives count
        batch_loss_val, true_positives_count = validate_one_batch(model, inputs_val, targets_val, loss_fn)
            
        # calculate average loss and appends true positives count
        total_val_loss += batch_loss_val
        total_data_points_val += len(inputs_val)
        true_positives_val += true_positives_count
    
    average_epoch_loss_val = total_val_loss/total_data_points_val
    val_accuracy = true_positives_val/total_data_points_val
    val_accuracy_history.append(val_accuracy)    
    val_loss.append(average_epoch_loss_val)
    
        
    # Print every epoch's metrics
    elapsed_time = time.time() - start_time
    print(f"epoch {epoch + 1}, average train loss: {average_epoch_loss_train}, average val loss: {average_epoch_loss_val}, val accuracy: {val_accuracy},  training time : {elapsed_time}")
    
      

epoch 1, average train loss: 1.7836974520236253, average val loss: 1.7813379786838994, val accuracy: 0.21573208722741433,  training time : 2.9156482219696045
epoch 2, average train loss: 1.7834785602986813, average val loss: 1.789648190094303, val accuracy: 0.19470404984423675,  training time : 2.095219373703003
epoch 3, average train loss: 1.7815971709787846, average val loss: 1.7851098293084593, val accuracy: 0.205607476635514,  training time : 1.9737038612365723
epoch 4, average train loss: 1.7747184906154871, average val loss: 1.7953519204695276, val accuracy: 0.220404984423676,  training time : 1.861588716506958
epoch 5, average train loss: 1.7767615742981433, average val loss: 1.7919336234297707, val accuracy: 0.1923676012461059,  training time : 2.1257686614990234
epoch 6, average train loss: 1.762551800906658, average val loss: 1.785548113588232, val accuracy: 0.21261682242990654,  training time : 2.2622108459472656
epoch 7, average train loss: 1.7333800189197064, average val l

In [17]:
# Test the model
model.eval()        
total_test_loss = 0. 
total_data_points_test = 0
true_positives_test = 0
i=0
for test_batch in test_loader:
    inputs_test = test_batch[0].to(device)
    targets_test = test_batch[1].to(device)
    # Returns loss and true positives count
    batch_loss_test, true_positives_count = validate_one_batch(model, inputs_test, targets_test, loss_fn)
            
    # calculate average loss and appends true positives count
    total_test_loss += batch_loss_test
    total_data_points_test += len(inputs_test)
    true_positives_test += true_positives_count
    
average_epoch_loss_test = total_test_loss/total_data_points_test
accuracy_test = true_positives_test/total_data_points_test
print(f"average test loss: {average_epoch_loss_test}, accuracy: {accuracy_test}")
    


average test loss: 1.8184774796983438, accuracy: 0.1499605367008682
