# Imports

In [1]:
from torch.utils.data import DataLoader, random_split
import torch
from data import ParenthesizationDataset, ParenthesizationModel
from train import train_one_epoch, evaluate_model
import matplotlib.pyplot as plt

# Initialization
Set the parameters here for training and initialize the train/test datasets, data loaders, model, loss function and optimizer.

In [2]:
# Hyperparameters: change the experiment by editing these values only.
# `n` is passed to both the dataset and the model constructors below.
n = 7
epochs = 50
train_split, test_split = 0.8, 0.2
learning_rate = 0.001
momentum = 0.9
batch_size = 32

# Split the full dataset into train/test subsets by the fractions above.
training_dataset, test_dataset = random_split(ParenthesizationDataset(n), [train_split, test_split])
model = ParenthesizationModel(n)
loss_fn = torch.nn.CrossEntropyLoss()
# Use the named hyperparameters rather than repeating their literal values,
# so that editing them above actually changes the training run.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
training_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)

# Training the model
* Call `train_one_epoch` once per epoch; each call trains the model once over the entire training dataset.
* Print out the epoch number and loss after each training call to verify that the loss is going down.
* Record the loss values in a list so that it can be plotted in the cell below.
* After training, call `evaluate_model` to get the confusion matrix.
* Save the model as `models/linear_model_{n}.pt`.

In [3]:
# Train for `epochs` passes, recording the loss of each pass so it can be plotted later.
loss = []
for epoch in range(epochs):
    model.train(True)
    current_loss = train_one_epoch(training_loader, model, loss_fn, optimizer)
    print(f"Epoch {epoch}, Loss {current_loss}")
    loss.append(current_loss)
# Evaluate on the held-out split to obtain the confusion matrix.
confusion_matrix = evaluate_model(model, test_dataset)
# Include n in the filename so models trained for different n do not overwrite each other.
torch.save(model.state_dict(), f"models/linear_model_{n}.pt")

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>

# Plot the loss curve
Plot the loss curve with appropriate figure title and axis labels. Save the resulting figure in `figures/loss_curve_{n}.png`.

In [None]:
# Plot the per-epoch training loss and save it to disk
plt.plot(loss)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title(f"Loss Curve for n={n}")
# Save before showing: once plt.show() has displayed the figure it is released,
# so a later savefig() would write an empty image.
plt.savefig(f"figures/loss_curve_{n}.png")
plt.show()

# Analyze accuracy using the confusion matrix
- Print out the accuracy within each of the prediction classes.
- Print out the size of each prediction class.
- Print out the overall accuracy.

In [None]:
# Print the confusion matrix, then per-class accuracy, per-class size and overall accuracy.
# NOTE(review): assumes row i of confusion_matrix holds the counts for class i,
# as the accuracy formula below already relies on -- confirm against evaluate_model.
print(confusion_matrix)
for i in range(2):
    class_size = sum(confusion_matrix[i])
    print(f"Accuracy for class {i} = {confusion_matrix[i][i]/class_size*100}%")
    print(f"Size of class {i} = {class_size}")
print(f"Overall accuracy = {sum([confusion_matrix[i][i] for i in range(2)])/sum([confusion_matrix[i][0]+confusion_matrix[i][1] for i in range(2)])*100}%")

# Plot the model weights
Can you interpret what the model is doing? Plot the model weights using `plt.imshow()` to get a heatmap. Choose a colormap from https://matplotlib.org/stable/users/explain/colors/colormaps.html that you prefer. I default to the `bwr` colormap, where negative values are blue, positive values are red, and values close to zero are white.

In [None]:
# Draw one heatmap per output class from the rows of the model's `fc` weights
class0, class1 = model.fc.weight[0], model.fc.weight[1]
for class_weights in (class0, class1):
    # unsqueeze(0) adds a leading axis so the weight row is drawn as a one-row image;
    # detach() drops the autograd graph so the tensor can be converted to numpy.
    heat = class_weights.unsqueeze(0).detach().numpy()
    plt.imshow(heat, cmap="bwr")
    plt.show()

# "Translate" the model into code.
Implement `simple_evaluate`, which condenses the model's "logic" into a single if-else statement. Run this evaluation function over the test set to produce a new confusion matrix and see how it performs compared to the model you trained.

In [None]:
# Hand-written stand-in for the linear model: a single two-way decision.
def simple_evaluate(input):
    """Return a one-hot class prediction based on the two ends of `input`.

    Class 0 (``[1, 0]``) is predicted when the first element is smaller than
    the second, or the last element is smaller than the second-to-last.
    Otherwise class 1 (``[0, 1]``) is predicted.
    """
    one_hot = [1, 0] if (input[0] < input[1] or input[-1] < input[-2]) else [0, 1]
    return torch.tensor(one_hot)

# Rebuild the confusion matrix using the hand-written rule instead of the trained model.
# Rows are indexed by the true label, columns by the predicted class.
confusion_matrix = [[0, 0], [0, 0]]
for sample, label in test_dataset:
    output = simple_evaluate(sample)
    confusion_matrix[label][torch.argmax(output)] += 1
for i in range(2):
    print(f"Accuracy for class {i} = {confusion_matrix[i][i]/sum(confusion_matrix[i])*100}%")
# Overall accuracy is reported once, after the per-class loop (it was previously
# indented into the loop and printed twice).
print(f"Overall accuracy = {sum([confusion_matrix[i][i] for i in range(2)])/sum([confusion_matrix[i][0]+confusion_matrix[i][1] for i in range(2)])*100}%")