In [1]:
import torch
import pickle
from torch import nn
from torch.utils.data import DataLoader

print("PyTorch version:")
print(torch.__version__)
print("GPU Detected:")
# Report whether any supported accelerator (Apple MPS or CUDA) is usable.
print(torch.backends.mps.is_available() or torch.cuda.is_available())

# NOTE(review): `os` is not used by any cell visible here; the import is kept
# in case other cells rely on it.
import os

# Pick the best available device instead of hard-coding "mps", so the
# notebook also runs on CUDA machines and on CPU-only machines.
if torch.backends.mps.is_available():
    gpu = torch.device("mps")
elif torch.cuda.is_available():
    gpu = torch.device("cuda:0")
else:
    gpu = torch.device("cpu")

PyTorch version:
2.2.0
GPU Detected:
True


In [2]:
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only load files that come from a trusted source.
with open("./data/train_merged", "rb") as f:
    train_text = pickle.load(f)
with open("./data/test_merged", "rb") as f:
    test_text = pickle.load(f)

batch_size = 32
# Training batches are reshuffled on every pass over the data.
train_text_data = DataLoader(train_text, batch_size=batch_size, shuffle=True)
# Evaluation visits every example once per pass, so shuffling adds nothing;
# a fixed order keeps repeated evaluations comparable.
test_text_data = DataLoader(test_text, batch_size=batch_size, shuffle=False)

In [3]:
# Number of examples in the training split, then in the test split.
print(len(train_text_data.dataset), len(test_text_data.dataset), sep="\n")

5770
1442


In [4]:
# Pull a single training batch and look at its dimensions.
batch_size = train_text_data.batch_size
data, label = next(iter(train_text_data))
print("shape: {0}".format(data.size()))
# Dimension 1 of the batch is later used to size the classifier's linear layer.
sequence_length = data.shape[1]
print(batch_size)

# The batch is the leading dimension, so the GRU must use batch_first=True.

shape: torch.Size([32, 1001, 300])
32


In [5]:
class RNNClassifier(nn.Module):
    """GRU encoder followed by a linear classifier over all time steps.

    The GRU output of every time step is flattened into one vector per example
    and projected to ``num_classes`` log-probabilities. Because the linear
    layer is sized from the sequence length, the model only accepts inputs
    with exactly that many time steps.

    Args:
        hidden_size: Number of features in the GRU hidden state.
        input_size: Number of features per time step of the input.
        num_layers: Number of stacked GRU layers.
        num_classes: Number of output classes.
        dropout: Dropout probability passed to ``nn.GRU``.
        activation_fn: Stored as ``self.nonlinearity``. It is not applied
            anywhere in ``forward``.
        seq_len: Number of time steps per input sequence, used to size the
            linear layer. When omitted, the notebook-level ``sequence_length``
            variable is used, which is the previous behaviour.
    """

    def __init__(
        self,
        hidden_size,
        input_size,
        num_layers,
        num_classes,
        dropout,
        activation_fn,
        seq_len=None,
    ):
        super().__init__()
        if seq_len is None:
            # Backward-compatible fallback to the notebook global that is set
            # when a batch is inspected.
            seq_len = sequence_length
        self.gru = nn.GRU(
            input_size, hidden_size, num_layers, dropout=dropout, batch_first=True
        )
        self.fc = nn.Linear(hidden_size * seq_len, num_classes)
        self.nonlinearity = activation_fn

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, num_classes)``.

        Args:
            x: Input tensor laid out as ``(batch, seq_len, input_size)``.
        """
        # Random initial hidden state, created directly on the input's device
        # and dtype so the model works wherever it has been moved to.
        # NOTE(review): h0 is also random in eval mode, so repeated
        # evaluations of the same input are not bit-identical.
        h0 = torch.randn(
            self.gru.num_layers,
            x.size(0),
            self.gru.hidden_size,
            device=x.device,
            dtype=x.dtype,
        )

        out, _ = self.gru(x, h0)
        # Flatten (batch, seq_len, hidden) -> (batch, seq_len * hidden).
        out = out.reshape(out.shape[0], -1)
        out = self.fc(out)

        return nn.functional.log_softmax(out, dim=-1)


def _evaluate(dataloader, nll_criterion, model, device):
    """Return ``(average_loss, accuracy_percent)`` of ``model`` over ``dataloader``.

    Runs under ``torch.no_grad()``. The caller is responsible for switching the
    model between eval and train mode. Raises ``ZeroDivisionError`` if the
    dataloader yields no batches.
    """
    num_correct = 0
    total_examples = 0
    total_loss = 0

    with torch.no_grad():
        for batch_data, batch_labels in dataloader:
            if device is not None:
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)

            log_probs = model(batch_data)
            batch_count = batch_labels.size(0)

            # Assumes the criterion averages over the batch (the default for
            # nn.NLLLoss), so scale by the batch size to accumulate a total.
            total_loss += nll_criterion(log_probs, batch_labels) * batch_count

            # Predicted class is the argmax over the class dimension.
            predicted_labels = log_probs.argmax(dim=1)
            total_examples += batch_count
            num_correct += torch.sum(predicted_labels == batch_labels)

    accuracy = 100 * num_correct / total_examples
    average_loss = total_loss / total_examples
    return average_loss, accuracy


def train(
    train_dataloader,
    test_dataloader,
    nll_criterion,
    num_epochs,
    ffnn,
    ffnn_optimizer,
    eval_every=100,
):
    """Train ``ffnn`` and periodically report test loss and accuracy.

    Args:
        train_dataloader: Iterable of ``(data, labels)`` training batches.
        test_dataloader: Iterable of ``(data, labels)`` evaluation batches.
        nll_criterion: Loss applied to the model's output and the labels.
        num_epochs: Number of passes over ``train_dataloader``.
        ffnn: The model to train. Batches are moved to the device of its
            first parameter before each forward pass.
        ffnn_optimizer: Optimizer that updates the parameters of ``ffnn``.
        eval_every: Evaluate on ``test_dataloader`` after every this many
            gradient updates. Defaults to 100.

    Raises:
        ValueError: If ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError("eval_every must be a positive integer")

    # Follow the model's own device rather than a notebook-level global, so
    # the batches always end up where the parameters are.
    first_param = next(ffnn.parameters(), None)
    device = first_param.device if first_param is not None else None

    # Make sure training starts in train mode even if the caller left the
    # model in eval mode.
    ffnn.train()

    # A counter for the number of gradient updates we've performed.
    num_iter = 0

    for epoch in range(num_epochs):
        print("Starting epoch {}".format(epoch + 1))
        for data, labels in train_dataloader:
            if device is not None:
                data = data.to(device)
                labels = labels.to(device)

            # Forward pass: predicted log distribution, then the loss.
            predicted = ffnn(data)
            batch_loss = nll_criterion(predicted, labels)

            # Clear stale gradients, backpropagate, and take a gradient step.
            ffnn_optimizer.zero_grad()
            batch_loss.backward()
            ffnn_optimizer.step()

            num_iter += 1

            if num_iter % eval_every == 0:
                # Eval mode turns off dropout; always restore train mode
                # afterwards, even if evaluation raises.
                ffnn.eval()
                try:
                    average_test_loss, accuracy = _evaluate(
                        test_dataloader, nll_criterion, ffnn, device
                    )
                finally:
                    ffnn.train()
                print(
                    "Iteration {}. Test Loss {}. Test Accuracy {}.".format(
                        num_iter, average_test_loss, accuracy
                    )
                )

In [6]:
activation_fn = nn.ReLU()
# NOTE: nn.GRU applies dropout only between stacked layers, so with
# num_layers=1 the dropout value below has no effect.
model = RNNClassifier(
    hidden_size=300,
    input_size=300,
    num_layers=1,
    num_classes=2,
    dropout=0.5,
    activation_fn=activation_fn,
)
# Move the model to the target device before building the optimiser, so the
# optimiser is constructed from the parameters as they live on that device.
model.to(gpu)
nll_criterion = nn.NLLLoss()
optimiser = torch.optim.Adam(model.parameters(), lr=0.00005)
# Display the model summary as the cell output.
model

  from .autonotebook import tqdm as notebook_tqdm


RNNClassifier(
  (gru): GRU(300, 300, batch_first=True, dropout=0.5)
  (fc): Linear(in_features=300300, out_features=2, bias=True)
  (nonlinearity): ReLU()
)

In [7]:
# Launch the training run, echoing the batch size first for the record.
num_epochs = 10
print(f"batch size:{batch_size}")
train(
    train_dataloader=train_text_data,
    test_dataloader=test_text_data,
    nll_criterion=nll_criterion,
    num_epochs=num_epochs,
    ffnn=model,
    ffnn_optimizer=optimiser,
)


batch size:32
Starting epoch 1
Iteration 100. Test Loss 0.2759212255477905. Test Accuracy 89.32038879394531.
Starting epoch 2
Iteration 200. Test Loss 0.20969322323799133. Test Accuracy 91.1234359741211.
Iteration 300. Test Loss 0.19367587566375732. Test Accuracy 92.02496337890625.
Starting epoch 3
Iteration 400. Test Loss 0.17116908729076385. Test Accuracy 93.34258270263672.
Iteration 500. Test Loss 0.1728503704071045. Test Accuracy 93.13453674316406.
Starting epoch 4
Iteration 600. Test Loss 0.1691211313009262. Test Accuracy 93.41192626953125.
Iteration 700. Test Loss 0.16025646030902863. Test Accuracy 93.48127746582031.
Starting epoch 5
Iteration 800. Test Loss 0.17392055690288544. Test Accuracy 93.55062103271484.
Iteration 900. Test Loss 0.1614162176847458. Test Accuracy 93.7586669921875.
Starting epoch 6
Iteration 1000. Test Loss 0.1685205101966858. Test Accuracy 93.41192626953125.
Starting epoch 7
Iteration 1100. Test Loss 0.17373643815517426. Test Accuracy 93.6199722290039.
Iter

Iteration 1800. Test Loss 0.17262473702430725. Test Accuracy 93.48127746582031.
batch size 32, lr 0.00001, dropout 0.5, inner nodes 256

Iteration 1800. Test Loss 0.2566724419593811. Test Accuracy 95.21498107910156.
batch size 32, lr 0.0005, dropout 0.5, inner nodes 256

Iteration 1800. Test Loss 0.2761397361755371. Test Accuracy 94.86824035644531.
batch size 32, lr 0.0001, dropout 0.4, inner nodes 300

Iteration 1800. Test Loss 0.2023600935935974. Test Accuracy 95.2843246459961.
batch size 32, lr 0.0001, dropout 0.5, inner nodes 300

Iteration 1800. Test Loss 0.16241158545017242. Test Accuracy 95.49237060546875.
batch size 32, lr 0.00005, dropout 0.5, inner nodes 300



Earlier run without batching: 915 minutes, about 90% test accuracy after 5 epochs.

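There appears to be an issue with the loss function: in the log below, the test loss rises from about 1.7 to about 7 while the test accuracy improves from about 78% to about 90%.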
```
Starting epoch 1
Iteration 500. Test Loss 1.7329039573669434. Test Accuracy 77.62973022460938.
Iteration 1000. Test Loss 2.110083818435669. Test Accuracy 84.50210571289062.
Iteration 1500. Test Loss 1.8402926921844482. Test Accuracy 85.0631103515625.
Iteration 2000. Test Loss 3.071587562561035. Test Accuracy 84.2917251586914.
Iteration 2500. Test Loss 2.9173219203948975. Test Accuracy 86.81626892089844.
Iteration 3000. Test Loss 4.660653591156006. Test Accuracy 84.71248626708984.
Iteration 3500. Test Loss 3.0766427516937256. Test Accuracy 85.0631103515625.
Iteration 4000. Test Loss 2.9936556816101074. Test Accuracy 86.9565200805664.
Iteration 4500. Test Loss 4.268098831176758. Test Accuracy 86.11500549316406.
Iteration 5000. Test Loss 3.2028002738952637. Test Accuracy 88.07854461669922.
Iteration 5500. Test Loss 3.206970691680908. Test Accuracy 87.8681640625.
Starting epoch 2
Iteration 6000. Test Loss 3.411480665206909. Test Accuracy 88.56942749023438.
Iteration 6500. Test Loss 3.215914249420166. Test Accuracy 87.37728118896484.
Iteration 7000. Test Loss 3.1972854137420654. Test Accuracy 88.70967864990234.
Iteration 7500. Test Loss 4.611752510070801. Test Accuracy 84.57222747802734.
Iteration 8000. Test Loss 3.449533462524414. Test Accuracy 88.2889175415039.
Iteration 8500. Test Loss 4.394710063934326. Test Accuracy 89.34081268310547.
Iteration 9000. Test Loss 3.9236555099487305. Test Accuracy 89.34081268310547.
Iteration 9500. Test Loss 3.7748982906341553. Test Accuracy 88.56942749023438.
Iteration 10000. Test Loss 4.514174461364746. Test Accuracy 86.9565200805664.
Iteration 10500. Test Loss 3.833115339279175. Test Accuracy 87.44740295410156.
Iteration 11000. Test Loss 4.04207181930542. Test Accuracy 87.51753234863281.
Starting epoch 3
Iteration 11500. Test Loss 5.0256757736206055. Test Accuracy 88.70967864990234.
Iteration 12000. Test Loss 4.844799518585205. Test Accuracy 87.93828582763672.
Iteration 12500. Test Loss 4.611722946166992. Test Accuracy 89.06031036376953.
Iteration 13000. Test Loss 6.31632661819458. Test Accuracy 87.3071517944336.
Iteration 13500. Test Loss 6.104625225067139. Test Accuracy 87.72791290283203.
Iteration 14000. Test Loss 5.693385601043701. Test Accuracy 88.00841522216797.
Iteration 14500. Test Loss 5.045144557952881. Test Accuracy 89.34081268310547.
Iteration 15000. Test Loss 5.350050449371338. Test Accuracy 88.6395492553711.
Iteration 15500. Test Loss 5.098731517791748. Test Accuracy 88.56942749023438.
Iteration 16000. Test Loss 5.727728366851807. Test Accuracy 85.9747543334961.
Iteration 16500. Test Loss 4.773247241973877. Test Accuracy 89.83169555664062.
Iteration 17000. Test Loss 4.9621405601501465. Test Accuracy 89.55119323730469.
Starting epoch 4
Iteration 17500. Test Loss 6.257537841796875. Test Accuracy 88.35904693603516.
Iteration 18000. Test Loss 6.823083400726318. Test Accuracy 87.02664947509766.
Iteration 18500. Test Loss 6.571444034576416. Test Accuracy 88.07854461669922.
Iteration 19000. Test Loss 6.781907558441162. Test Accuracy 88.49929809570312.
Iteration 19500. Test Loss 6.27226448059082. Test Accuracy 89.06031036376953.
Iteration 20000. Test Loss 6.661423206329346. Test Accuracy 88.21879577636719.
Iteration 20500. Test Loss 6.769895076751709. Test Accuracy 89.69144439697266.
Iteration 21000. Test Loss 5.5185651779174805. Test Accuracy 90.18232727050781.
Iteration 21500. Test Loss 7.558089256286621. Test Accuracy 89.06031036376953.
Iteration 22000. Test Loss 6.829961776733398. Test Accuracy 89.55119323730469.
Iteration 22500. Test Loss 5.848672389984131. Test Accuracy 90.32257843017578.
Starting epoch 5
Iteration 23000. Test Loss 5.5816650390625. Test Accuracy 89.55119323730469.
Iteration 23500. Test Loss 7.699804782867432. Test Accuracy 87.3071517944336.
Iteration 24000. Test Loss 6.057502746582031. Test Accuracy 89.27069091796875.
Iteration 24500. Test Loss 6.385693550109863. Test Accuracy 89.55119323730469.
Iteration 25000. Test Loss 6.076298713684082. Test Accuracy 90.04207611083984.
Iteration 25500. Test Loss 6.222282409667969. Test Accuracy 89.55119323730469.
Iteration 26000. Test Loss 6.031567096710205. Test Accuracy 89.90182495117188.
Iteration 26500. Test Loss 7.921655654907227. Test Accuracy 90.25245666503906.
Iteration 27000. Test Loss 7.22239875793457. Test Accuracy 90.32257843017578.
Iteration 27500. Test Loss 7.10892391204834. Test Accuracy 90.532958984375.
Iteration 28000. Test Loss 6.968836784362793. Test Accuracy 89.55119323730469.
Iteration 28500. Test Loss 6.57443904876709. Test Accuracy 90.18232727050781.
```
Total run time: 915 minutes; final test accuracy about 90%.