In this notebook I am only testing the code that will go into `script.py`. It is purely experimental and will not be needed elsewhere.

In `script.py` we will do the following:
1. Import necessary libraries & get data
2. Encode the categories
3. Prepare the Tokenization for the input text (i.e., the news headlines in our case)
4. Create the PyTorch Dataset class - this is where we will tokenize the text data
5. Split the data into training & testing
6. Create the training/testing PyTorch DataLoaders
7. Create the model class & initialize a model
8. Select loss function, optimizer, accuracy metrics - this step is only here for clarity of process; we will actually select them in step 10
9. Create training loop
10. Set up the `main()` function which will start everything

# 1. Import libraries & modules, and get data

In [2]:
!pip install s3fs



In [3]:
!pip install torchinfo



In [4]:
import torch
import transformers
import pandas as pd
import numpy as np
import os
import argparse
import torchmetrics

from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary
from transformers import DistilBertTokenizer, DistilBertModel
from tqdm.auto import tqdm
from typing import List, Dict, Tuple
from timeit import default_timer as timer



In [5]:
# Let's get data
s3_path = 's3://tk5-huggingface-multiclass-textclassification-bucket/training_data/newsCorpora.csv'

df = pd.read_csv(s3_path, 
                 sep='\t', 
                 names=['ID', 'TITLE', 'URL', 'PUBLISHER', 'CATEGORY', 'STORY', 'HOSTNAME', 'TIMESTAMP'])

print(f"First few rows of the dataframe:\n{df[['TITLE', 'CATEGORY']].head()}")

First few rows of the dataframe:
                                               TITLE CATEGORY
0  Fed official says weak data caused by weather,...        b
1  Fed's Charles Plosser sees high bar for change...        b
2  US open: Stocks fall after Fed official hints ...        b
3  Fed risks falling 'behind the curve', Charles ...        b
4  Fed's Plosser: Nasty Weather Has Curbed Job Gr...        b


In [6]:
# Let's update the CATEGORY variable as we did before

# Create the mapping through which we will update the CATEGORY variable
my_dict = {
    'b': 'BUSINESS',
    't': 'SCIENCE',
    'e': 'ENTERTAINMENT',
    'm': 'HEALTH'
}

# Create helper function
def update_category(x, dictionary: dict):
    return dictionary.get(x)

# Update the CATEGORY variable
df['CATEGORY'] = df['CATEGORY'].apply(lambda x: update_category(x, dictionary=my_dict))
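
One thing worth checking with this kind of mapping: `dictionary.get(x)` returns `None` for any code that is not in the mapping, so an unexpected code would silently become a missing category. Below is a minimal sketch of a sanity check, in plain Python. The helper `find_unmapped` and the code `'x'` are made up for illustration; they are not part of the dataset or the script.

```python
# Same mapping as in the cell above
my_dict = {
    'b': 'BUSINESS',
    't': 'SCIENCE',
    'e': 'ENTERTAINMENT',
    'm': 'HEALTH'
}

def find_unmapped(codes, dictionary):
    """Return the sorted list of distinct codes that have no entry in the mapping."""
    return sorted({code for code in codes if code not in dictionary})

# Hypothetical raw codes - 'x' is an invented code that is not in the mapping
raw_codes = ['b', 't', 'x', 'e', 'm', 'b']

# dict.get returns None when the key is missing
mapped = [my_dict.get(code) for code in raw_codes]
unmapped = find_unmapped(raw_codes, my_dict)

print(f"Mapped categories: {mapped}")
print(f"Codes without a mapping: {unmapped}")
```

If the list of unmapped codes is empty, the update of the CATEGORY variable did not introduce any missing values.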

In [7]:
df = df[['TITLE', 'CATEGORY']]
df.head()

Unnamed: 0,TITLE,CATEGORY
0,"Fed official says weak data caused by weather,...",BUSINESS
1,Fed's Charles Plosser sees high bar for change...,BUSINESS
2,US open: Stocks fall after Fed official hints ...,BUSINESS
3,"Fed risks falling 'behind the curve', Charles ...",BUSINESS
4,Fed's Plosser: Nasty Weather Has Curbed Job Gr...,BUSINESS


In [8]:
len(df)

422419

Also, we will only train the model on a fraction of the entire data. The reason is that first I want to make sure that everything is running as intended, before we use the entire data and incur higher costs. I don't want to find out that there's something wrong with my code after having used a GPU instance for hours and hours!

In [9]:
df = df.sample(frac=0.10, random_state=1) # selecting only 10% of the data
df = df.reset_index(drop=True)

In [10]:
len(df)

42242

In [11]:
df.head()

Unnamed: 0,TITLE,CATEGORY
0,Murdoch's bid for Time Warner rejected,BUSINESS
1,Rescuers close in on 3 trapped Honduran miners...,BUSINESS
2,Johnny Depp - Johnny Depp Served With Legal Pa...,ENTERTAINMENT
3,"Apple prepping move into ""smart home"" connecti...",SCIENCE
4,Ripped First Look: Dwayne Johnson as Brett Rat...,ENTERTAINMENT


In [12]:
print(f"Count by category: {df.groupby(['CATEGORY']).count()}")

Count by category:                TITLE
CATEGORY            
BUSINESS       11438
ENTERTAINMENT  15275
HEALTH          4566
SCIENCE        10963


# 2. Encode the categories

In [13]:
df['ENCODE_CAT'] = df.groupby(by=['CATEGORY']).ngroup() # the .ngroup() assigns a number to each unique category
df.head()

Unnamed: 0,TITLE,CATEGORY,ENCODE_CAT
0,Murdoch's bid for Time Warner rejected,BUSINESS,0
1,Rescuers close in on 3 trapped Honduran miners...,BUSINESS,0
2,Johnny Depp - Johnny Depp Served With Legal Pa...,ENTERTAINMENT,1
3,"Apple prepping move into ""smart home"" connecti...",SCIENCE,3
4,Ripped First Look: Dwayne Johnson as Brett Rat...,ENTERTAINMENT,1


In [14]:
# Quickly validate that our code is producing outputs as intended
df.drop_duplicates(subset=['CATEGORY', 'ENCODE_CAT'])

Unnamed: 0,TITLE,CATEGORY,ENCODE_CAT
0,Murdoch's bid for Time Warner rejected,BUSINESS,0
2,Johnny Depp - Johnny Depp Served With Legal Pa...,ENTERTAINMENT,1
3,"Apple prepping move into ""smart home"" connecti...",SCIENCE,3
15,Some Mangos Sold In New Jersey Recalled,HEALTH,2


In [19]:
small_sample_df = df.drop_duplicates(subset=['CATEGORY', 'ENCODE_CAT'])
small_sample_df = small_sample_df.sort_values(by=['ENCODE_CAT']).reset_index(drop=True)
small_sample_df

Unnamed: 0,TITLE,CATEGORY,ENCODE_CAT
0,Murdoch's bid for Time Warner rejected,BUSINESS,0
1,Johnny Depp - Johnny Depp Served With Legal Pa...,ENTERTAINMENT,1
2,Some Mangos Sold In New Jersey Recalled,HEALTH,2
3,"Apple prepping move into ""smart home"" connecti...",SCIENCE,3


In [24]:
categories = list(small_sample_df['CATEGORY'])
encoded_cats = list(small_sample_df['ENCODE_CAT'])

# Map each category name to its encoded index
class_to_idx = dict(zip(categories, encoded_cats))

class_to_idx, categories, encoded_cats

({'BUSINESS': 0, 'ENTERTAINMENT': 1, 'HEALTH': 2, 'SCIENCE': 3},
 ['BUSINESS', 'ENTERTAINMENT', 'HEALTH', 'SCIENCE'],
 [0, 1, 2, 3])
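
With its default settings, `groupby` sorts the group keys, so `.ngroup()` numbers the categories in alphabetical order. The sketch below reproduces the same encoding in plain Python, and also builds the reverse mapping, which is handy for turning a predicted index back into a category name. The names `labels` and `idx_to_class` are only illustrative.

```python
# A few hypothetical labels, in arbitrary order
labels = ['BUSINESS', 'BUSINESS', 'ENTERTAINMENT', 'SCIENCE', 'ENTERTAINMENT', 'HEALTH']

# Enumerate the sorted unique labels - this mirrors the numbering seen above
class_to_idx = {name: idx for idx, name in enumerate(sorted(set(labels)))}

# Reverse mapping: from encoded index back to the category name
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

# Encode every label
encoded = [class_to_idx[name] for name in labels]

print(f"class_to_idx: {class_to_idx}")
print(f"idx_to_class: {idx_to_class}")
print(f"Encoded labels: {encoded}")
```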

# 3. Prepare the Text Tokenization

In [14]:
from transformers import DistilBertTokenizer

# Get the tokenizer of choice
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')



In [15]:
# Tokenize input data - just an example!
inputs = tokenizer.encode_plus(
    "I love football and horse-riding",
    "I went to the University of Toronto",
    add_special_tokens=True, # adds the [CLS] and [SEP] tokens
    max_length=20,
    padding='max_length',
    truncation=True,
    return_token_type_ids=True, # 0s for the first sentence, 1s for the second; padding positions are also 0s, so a single padded sentence is all 0s
    return_attention_mask=True # 1s where the model should pay attention (real tokens), 0s over the padding
)

inputs

{'input_ids': [101, 1045, 2293, 2374, 1998, 3586, 1011, 5559, 102, 1045, 2253, 2000, 1996, 2118, 1997, 4361, 102, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]}

In [16]:
print(f"Input IDs: {inputs['input_ids']}")
print(f"Token Type IDs: {inputs['token_type_ids']}")
print(f"Attention Mask: {inputs['attention_mask']}")

Input IDs: [101, 1045, 2293, 2374, 1998, 3586, 1011, 5559, 102, 1045, 2253, 2000, 1996, 2118, 1997, 4361, 102, 0, 0, 0]
Token Type IDs: [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
Attention Mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
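
To make the padding, truncation and attention mask more concrete, here is a toy sketch in plain Python. It is not how the Hugging Face tokenizer is implemented; it only mimics the shape of the output for a single sentence. The special token IDs are the ones visible in the output above (101 at the start, 102 as separator, 0 for padding), and `encode_single` is a hypothetical helper.

```python
from typing import Dict, List

# Special token IDs as they appear in the output above
PAD_ID, CLS_ID, SEP_ID = 0, 101, 102

def encode_single(token_ids: List[int], max_length: int) -> Dict[str, List[int]]:
    """Wrap token IDs in [CLS]/[SEP], truncate, then pad up to max_length."""
    # Truncate so that [CLS] + tokens + [SEP] fits within max_length
    body = token_ids[: max_length - 2]
    input_ids = [CLS_ID] + body + [SEP_ID]

    # Real tokens get a 1 in the attention mask
    attention_mask = [1] * len(input_ids)

    # Pad both lists up to max_length; padding gets a 0 in the attention mask
    n_pad = max_length - len(input_ids)
    return {
        'input_ids': input_ids + [PAD_ID] * n_pad,
        'attention_mask': attention_mask + [0] * n_pad,
    }

# Token IDs of "I love football and horse-riding", copied from the output above
first_sentence = [1045, 2293, 2374, 1998, 3586, 1011, 5559]

padded = encode_single(first_sentence, max_length=12)    # shorter than max_length: gets padded
truncated = encode_single(first_sentence, max_length=6)  # longer than max_length: gets truncated

print(f"Padded: {padded}")
print(f"Truncated: {truncated}")
```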


# 4. Create the PyTorch Dataset class

Here, we will prepare the data so that the model can learn from it. Remember, the model cannot process text data directly: first we need to tokenize it (i.e., convert it into numbers).

In [49]:
class NewsDataset(Dataset):

    # Define the __init__() method:
    def __init__(self, data, tokenizer, max_length):
        super().__init__()
        
        # 1. Initialize the data, tokenizer & allocated maximum length for the model inputs (remember, our task deals with news headlines, so will choose max_length accordingly)
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.len = len(data)

    # Define the __getitem__ method:
    def __getitem__(self, index):
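        # 1. Get headline from the source dataframe
        # Use positional indexing (.iloc): after train_test_split the dataframe index is shuffled and not reset,
        # so a label-based lookup such as self.data.TITLE[index] would fail with a KeyError for labels that landed in the other split
        headline = str(self.data.iloc[index, 0])
        headline = " ".join(headline.split()) # collapse repeated whitespace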

        # 2. Tokenize the headline
        headline_tokenized = self.tokenizer.encode_plus(
            
            headline, # input text
            add_special_tokens = True,
            max_length = self.max_length,
            padding = 'max_length',
            truncation = True,
            return_attention_mask = True,
            return_token_type_ids = True,
        
        )

        ids = headline_tokenized['input_ids']
        mask = headline_tokenized['attention_mask']
        
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'targets': torch.tensor(self.data.iloc[index, 2], dtype=torch.long)
        }

    # Define the __len__() method:
    def __len__(self):
        return self.len

# 5. Split the data into training and testing

In [50]:
import sklearn
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, shuffle=True, stratify=df['CATEGORY'])

In [51]:
print(f"Training data has shape: {train_df.shape}")
print(f"Testing data has shape: {test_df.shape}")

Training data has shape: (33793, 3)
Testing data has shape: (8449, 3)
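
The shapes above can be checked with some quick arithmetic. The sketch below assumes that the number of test rows is rounded up, which is consistent with the shapes printed above. It also computes the share of each category (using the counts from section 1), which is what `stratify` tries to preserve in both splits.

```python
import math

n_total = 42242
test_size = 0.2

# Assumption: the test count is rounded up (this matches the shapes printed above)
n_test = math.ceil(test_size * n_total)
n_train = n_total - n_test

# Category counts taken from the "Count by category" output in section 1
category_counts = {'BUSINESS': 11438, 'ENTERTAINMENT': 15275, 'HEALTH': 4566, 'SCIENCE': 10963}

# With stratification, each split should keep roughly these shares
shares = {name: count / n_total for name, count in category_counts.items()}

print(f"Train rows: {n_train} | Test rows: {n_test}")
for name, share in shares.items():
    print(f"{name}: {share*100:.1f}% of the data, roughly {share * n_test:.0f} test rows")
```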


# 6. Create the Training/Testing PyTorch DataLoaders

In [52]:
MAX_LEN = 512
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 2

In [53]:
train_dataset = NewsDataset(data=train_df, tokenizer=tokenizer, max_length=MAX_LEN)
len(train_dataset)

33793

In [54]:
test_dataset = NewsDataset(data=test_df, tokenizer=tokenizer, max_length=MAX_LEN)
len(test_dataset)

8449

In [55]:
# Train DataLoader
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size = TRAIN_BATCH_SIZE,
                              shuffle=True, 
                              num_workers=0)

In [56]:
# Test DataLoader
test_dataloader = DataLoader(dataset=test_dataset, 
                             batch_size=VALID_BATCH_SIZE, 
                             shuffle=False, 
                             num_workers=0)
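
With the batch sizes above, we can work out how many batches each DataLoader will yield per epoch. A DataLoader that keeps the last incomplete batch yields `ceil(n_samples / batch_size)` batches. The helper `num_batches` below is just for illustration.

```python
import math

def num_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches per epoch for a given dataset size and batch size."""
    if drop_last:
        return n_samples // batch_size        # the incomplete last batch is dropped
    return math.ceil(n_samples / batch_size)  # the incomplete last batch is kept

train_batches = num_batches(n_samples=33793, batch_size=4)
test_batches = num_batches(n_samples=8449, batch_size=2)

# If we print progress every 100th batch, this is how many printouts one training epoch produces
train_printouts = len(range(0, train_batches, 100))

print(f"Train batches per epoch: {train_batches}")
print(f"Test batches per epoch: {test_batches}")
print(f"Progress printouts per training epoch: {train_printouts}")
```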

# 7. Create the model class

Here we will adapt the DistilBERT architecture to our own specific problem by adding a few additional layers on top of it

In [57]:
class FT_DistilBERT(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.block_1 = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.layer_2 = nn.Linear(in_features=768,
                                 out_features=768)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(p=0.3)
        self.classifier_layer = nn.Linear(in_features=768,
                                          out_features=num_classes)


    def forward(self, input_ids, mask_ids):
        # 1. Send through the DistilBERT pre-trained model
        output = self.block_1(input_ids = input_ids,
                              attention_mask = mask_ids)
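        hidden_state = output[0] # last hidden state, shape [batch_size, seq_len, 768]
        pooler = hidden_state[:, 0] # embedding of the first ([CLS]) token, shape [batch_size, 768]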
        
        # 2. Send through the linear layer - this serves to increase the representational capacity of our model
        output = self.layer_2(pooler)
        # 3. Send through a non-linear activation function
        output = self.activation(output)
        # 4. Apply dropout to fight over-fitting
        output = self.dropout(output)
        # 5. Get the classification prediction (in logits)
        output = self.classifier_layer(output)

        return output

In [26]:
model = FT_DistilBERT(num_classes=4)



In [27]:
# Let's view a summary of our model
summary(model=model, 
        input_data={"input_ids": torch.randint(low=0, high=25, size=(1, 512), dtype=torch.long), "mask_ids":torch.ones(size=(1, 512), dtype=torch.long)},
        col_names=['input_size', 'output_size', 'trainable', 'num_params'], 
        row_settings=['var_names'], 
        col_width=20)

Layer (type (var_name))                                      Input Shape          Output Shape         Trainable            Param #
FT_DistilBERT (FT_DistilBERT)                                --                   [1, 4]               True                 --
├─DistilBertModel (block_1)                                  --                   [1, 512, 768]        True                 --
│    └─Embeddings (embeddings)                               [1, 512]             [1, 512, 768]        True                 --
│    │    └─Embedding (word_embeddings)                      [1, 512]             [1, 512, 768]        True                 23,440,896
│    │    └─Embedding (position_embeddings)                  [1, 512]             [1, 512, 768]        True                 393,216
│    │    └─LayerNorm (LayerNorm)                            [1, 512, 768]        [1, 512, 768]        True                 1,536
│    │    └─Dropout (dropout)                                [1, 512, 768]        [1, 512,
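
The parameter counts in the summary can be verified by hand, and the same arithmetic gives the size of the two layers we added on top of DistilBERT. This is a small sketch: `linear_params` is an illustrative helper, and 30522 is the vocabulary size of `distilbert-base-uncased`.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features

hidden = 768
vocab_size = 30522    # vocabulary size of distilbert-base-uncased
max_positions = 512

# Embedding tables: one row of size `hidden` per entry
word_embeddings = vocab_size * hidden
position_embeddings = max_positions * hidden

# LayerNorm has one scale and one shift value per hidden unit
layer_norm = 2 * hidden

# The layers we added on top of DistilBERT
layer_2 = linear_params(in_features=768, out_features=768)
classifier_layer = linear_params(in_features=768, out_features=4)

print(f"word_embeddings: {word_embeddings:,}")
print(f"position_embeddings: {position_embeddings:,}")
print(f"LayerNorm: {layer_norm:,}")
print(f"layer_2: {layer_2:,}")
print(f"classifier_layer: {classifier_layer:,}")
```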

# 8. Select loss function, optimizer, and accuracy metrics

In [28]:
# Loss function
loss_fn = nn.CrossEntropyLoss()
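
Note that `nn.CrossEntropyLoss` expects the raw logits and applies the (log-)softmax internally, which is why the model returns logits. Here is a worked example in plain Python of what the loss computes for a single sample. The logits are made up, and the functions are a simplified sketch rather than PyTorch's implementation.

```python
import math

def softmax(logits):
    """Turn logits into probabilities (subtracting the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Negative log-probability that the model assigns to the correct class."""
    return -math.log(softmax(logits)[target])

# Hypothetical logits for one headline, over our 4 classes
logits = [2.0, 0.5, -1.0, 0.1]
target = 0  # BUSINESS

probs = softmax(logits)
loss = cross_entropy(logits, target)

# Softmax keeps the ordering of the logits, so the predicted class is the same either way
pred_from_logits = max(range(len(logits)), key=lambda i: logits[i])
pred_from_probs = max(range(len(probs)), key=lambda i: probs[i])

print(f"Probabilities: {[round(p, 3) for p in probs]}")
print(f"Loss: {loss:.3f}")
print(f"Predicted class from logits: {pred_from_logits} | from probabilities: {pred_from_probs}")
```

A model that is completely undecided between the 4 classes gets a loss of `log(4)`, roughly 1.386, so that is a useful reference point when reading the first training losses.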

In [29]:
# Optimizer
LEARNING_RATE = 1e-05
optimizer = torch.optim.Adam(params=model.parameters(),
                             lr=LEARNING_RATE)

In [83]:
# Accuracy metrics: Accuracy, Confusion Matrix
accuracy_fn = torchmetrics.Accuracy(task='multiclass', num_classes=4)
# conf_matrix = torchmetrics.ConfusionMatrix(task='multiclass', num_classes=4)
precision = torchmetrics.Precision(task='multiclass', num_classes=4)
recall = torchmetrics.Recall(task='multiclass', num_classes=4)
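
A note on how precision and recall are averaged across classes, since torchmetrics lets you choose this through the `average` argument. With micro averaging in a single-label multiclass problem, precision and recall are both equal to accuracy, so they add no new information; macro averaging (the mean of the per-class values) treats every class equally, which matters here because HEALTH is much smaller than the other categories. The sketch below shows the difference in plain Python, on made-up predictions.

```python
def per_class_counts(preds, targets, num_classes):
    """Count true positives, false positives and false negatives per class."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes
    for p, t in zip(preds, targets):
        if p == t:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class p, but it was wrong
            fn[t] += 1  # true class t was missed
    return tp, fp, fn

def micro_average(tp, errors):
    """Pool the counts of all classes, then compute one ratio."""
    return sum(tp) / (sum(tp) + sum(errors))

def macro_average(tp, errors):
    """Compute the ratio per class, then take the plain mean."""
    ratios = [tp[c] / (tp[c] + errors[c]) if (tp[c] + errors[c]) > 0 else 0.0
              for c in range(len(tp))]
    return sum(ratios) / len(ratios)

# Hypothetical predictions and targets for 8 headlines
preds   = [0, 1, 1, 3, 1, 2, 0, 1]
targets = [0, 1, 2, 3, 1, 2, 3, 0]

tp, fp, fn = per_class_counts(preds, targets, num_classes=4)
accuracy = sum(p == t for p, t in zip(preds, targets)) / len(targets)

micro_precision = micro_average(tp, fp)
micro_recall = micro_average(tp, fn)
macro_precision = macro_average(tp, fp)
macro_recall = macro_average(tp, fn)

print(f"Accuracy: {accuracy:.3f}")
print(f"Micro precision: {micro_precision:.3f} | Micro recall: {micro_recall:.3f}")
print(f"Macro precision: {macro_precision:.3f} | Macro recall: {macro_recall:.3f}")
```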

In [84]:
# To plot the confusion matrix, will install mlxtend
# !pip install mlxtend

In [85]:
# mlxtend has a good-looking function to visualize the confusion matrix with color scaling
# import mlxtend 

# 9. Set up the Training Loop

Here I will do the following:
1. Create a train step
2. Create a test/validation step
3. Create a combined training loop that involves both steps 1 and 2
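
Before writing the full functions, here is a minimal sketch of the same three pieces on a toy problem, so that the order of operations is easy to see. Everything in it is made up for illustration (random data, a single linear layer instead of DistilBERT, the helper names `train_one_batch` and `evaluate`); only the sequence of PyTorch calls is what matters.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy data: 32 samples, 8 features, 4 classes
features = torch.randn(32, 8)
targets = torch.randint(low=0, high=4, size=(32,))

toy_model = nn.Linear(in_features=8, out_features=4)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=toy_model.parameters(), lr=1e-2)

def train_one_batch(model, x, y):
    """One optimization step: forward pass, loss, zero grad, backward, step."""
    model.train()
    logits = model(x)
    loss = loss_fn(logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()  # plain float, so no computation graph is kept around

def evaluate(model, x, y):
    """Forward pass only, with gradient tracking switched off."""
    model.eval()
    with torch.inference_mode():
        logits = model(x)
        loss = loss_fn(logits, y).item()
        accuracy = (logits.argmax(dim=1) == y).float().mean().item()
    return loss, accuracy

# Combined loop: train for a few steps on the same batch, then evaluate
losses = [train_one_batch(toy_model, features, targets) for _ in range(50)]
final_loss, final_acc = evaluate(toy_model, features, targets)

print(f"First train loss: {losses[0]:.3f} | Last train loss: {losses[-1]:.3f}")
print(f"Eval loss: {final_loss:.3f} | Eval accuracy: {final_acc*100:.2f}%")
```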

In [82]:
metrics_of_choice = [accuracy_fn, precision, recall]
metrics_of_choice

[MulticlassAccuracy(), MulticlassPrecision(), MulticlassRecall()]

In [76]:
# 1. Train step

def train_step(model: torch.nn.Module, 
               train_dataloader: torch.utils.data.DataLoader, 
               loss_fn: torch.nn.Module, 
               optimizer: torch.optim.Optimizer, 
               metrics: List[torchmetrics.Metric], 
               device: torch.device):
    """
    This function conducts one training pass (one epoch) across all batches in a DataLoader

    Parameters:
        - model: a PyTorch model architecture that will be trained
        - train_dataloader: a PyTorch DataLoader that contains the training batches
        - loss_fn: a PyTorch loss function that measures how far the model's predictions are from the targets
        - optimizer: a PyTorch optimizer that determines how we adjust the weights of the model
        - metrics: a list of PyTorch (torchmetrics) metrics that we want to review as the model progresses through training
        - device: the target device on which we will train the model (e.g., CPU, GPU)
    """
    # 0. Set up variables that will contain the training loss, and accuracy metric
    total_train_loss = 0
    total_train_acc = 0
    total_train_prec = 0
    total_train_rec = 0
    
    ## I also want to keep a record of how the loss & accuracy metrics develop per each batch in an epoch - in case we want to plot the progress
    results_dict = {
        'batch_train_loss': [],
        'batch_train_acc': [],
        'batch_train_prec': [],
        'batch_train_rec': []
    } 
    
    
    # 1. Set the model in training mode
    model.train()

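    # 2. Send model and metrics to target device (the metric states must be on the same device as the predictions)
    model.to(device)
    for metric in metrics:
        metric.to(device)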

    # 3. Start the training step
    for idx, data in enumerate(train_dataloader):
        
        # 3.1 Get data into target device
        inputs = data['ids'].to(device)
        mask = data['mask'].to(device)
        targets = data['targets'].to(device)

        # 3.2 Run model on data
        outputs = model(input_ids=inputs,
                        mask_ids=mask)
        probabilities = torch.softmax(outputs,
                                      dim=1)
        predictions = torch.argmax(probabilities,
                                   dim=1)

        # 3.3 Calculate the loss & accuracy metrics
        ## Loss
        loss = loss_fn(outputs, targets)
        total_train_loss += loss.item() # .item() gives a plain float, so the computation graph is not kept in memory
        results_dict['batch_train_loss'].append(loss.item()) 

        ## Metrics - dispatch on the type of each metric
        for metric in metrics:
            
            if isinstance(metric, torchmetrics.classification.MulticlassAccuracy):
                accuracy = metric(predictions, targets).item()
                total_train_acc += accuracy
                results_dict['batch_train_acc'].append(accuracy)
            
            elif isinstance(metric, torchmetrics.classification.MulticlassPrecision):
                precision = metric(predictions, targets).item()
                total_train_prec += precision
                results_dict['batch_train_prec'].append(precision)
            
            elif isinstance(metric, torchmetrics.classification.MulticlassRecall):
                recall = metric(predictions, targets).item()
                total_train_rec += recall
                results_dict['batch_train_rec'].append(recall)

        
        # 3.4 Optimizer zero grad
        optimizer.zero_grad()

        # 3.5 Do backpropagation
        loss.backward()

        # 3.6 Apply optimizer step
        optimizer.step()


        # Print out what's happening every 100th batch - we get the running averages by dividing by the number of batches seen so far:
        if idx % 100 == 0:
            print(f"Training loss as at the {idx}-th batch: {total_train_loss/(idx+1):.3f}\n"
                  f"Train accuracy as at the {idx}-th batch: {total_train_acc/(idx+1)*100:.2f}%\n"
                  f"Train precision as at the {idx}-th batch: {total_train_prec/(idx+1):.3f}\n"
                  f"Train recall as at the {idx}-th batch: {total_train_rec/(idx+1):.3f}")

    # 4. Get the average train loss/accuracy/precision/recall per batch
    total_train_loss = total_train_loss / len(train_dataloader)
    total_train_acc = total_train_acc / len(train_dataloader)
    total_train_prec = total_train_prec / len(train_dataloader)
    total_train_rec = total_train_rec / len(train_dataloader)
    
    return total_train_loss, total_train_acc, total_train_prec, total_train_rec, results_dict

In [63]:
# 2. Test step

def test_step(model: torch.nn.Module, 
              test_dataloader: torch.utils.data.DataLoader, 
              loss_fn: torch.nn.Module, 
              metrics: List[torchmetrics.Metric], 
              device: torch.device):
    """
    This function conducts one validation pass across all batches in a DataLoader

    Parameters:
        - model: a PyTorch model architecture that will be evaluated
        - test_dataloader: a PyTorch DataLoader that contains the testing batches
        - loss_fn: a PyTorch loss function that measures how far the model's predictions are from the targets
        - metrics: a list of PyTorch (torchmetrics) metrics that we want to review as the model is evaluated
        - device: the target device on which we will evaluate the model (e.g., CPU, GPU)
    """
    # 0. Set up variables that will contain the test loss, and accuracy metric
    total_test_loss = 0
    total_test_acc = 0
    total_test_prec = 0
    total_test_rec = 0
    
    ## To keep a record of how the loss & accuracy metrics develop per each batch in an epoch
    results_dict = {
        'batch_test_loss': [],
        'batch_test_acc': [],
        'batch_test_prec': [],
        'batch_test_rec': []
    } 
    
    # 1. Set the model in evaluation mode
    model.eval()

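    # 2. Send model and metrics to target device (the metric states must be on the same device as the predictions)
    model.to(device)
    for metric in metrics:
        metric.to(device)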

    
    # 3. Start validation loop
    
    ## 3.1 Set model in inference mode
    with torch.inference_mode():

        for idx, data in enumerate(test_dataloader):
            
            ## 3.2 Get data into target device
            inputs = data['ids'].to(device)
            mask = data['mask'].to(device)
            targets = data['targets'].to(device)

            ## 3.3 Get predictions
            outputs = model(input_ids=inputs,
                            mask_ids=mask)
            probabilities = torch.softmax(outputs,
                                          dim=1)
            predictions = torch.argmax(probabilities,
                                       dim=1)

            ## 3.4 Estimate loss & metrics
            ### Loss
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            results_dict['batch_test_loss'].append(loss.item()) 
    
            ### Metrics - dispatch on the type of each metric
            for metric in metrics:
                
                if isinstance(metric, torchmetrics.classification.MulticlassAccuracy):
                    accuracy = metric(predictions, targets).item()
                    total_test_acc += accuracy
                    results_dict['batch_test_acc'].append(accuracy)
                
                elif isinstance(metric, torchmetrics.classification.MulticlassPrecision):
                    precision = metric(predictions, targets).item()
                    total_test_prec += precision
                    results_dict['batch_test_prec'].append(precision)
                
                elif isinstance(metric, torchmetrics.classification.MulticlassRecall):
                    recall = metric(predictions, targets).item()
                    total_test_rec += recall
                    results_dict['batch_test_rec'].append(recall)

            # Print out what's happening every 100th batch - we get the loss & metrics results per batch by dividing by the number of batches:
            if idx % 100 == 0:
                print(f"Testing loss as at the {idx}-th batch: {total_test_loss/(idx+1):.3f}\n"
                      f"Test accuracy as at the {idx}-th batch: {total_test_acc/(idx+1)*100:.2f}%\n"
                      f"Test precision as at the {idx}-th batch: {total_test_prec/(idx+1):.3f}\n"
                      f"Test recall as at the {idx}-th batch: {total_test_rec/(idx+1):.3f}")

    # 4. Get test loss/accuracy/precision/recall per batch
    total_test_loss = total_test_loss / len(test_dataloader)
    total_test_acc = total_test_acc / len(test_dataloader)
    total_test_prec = total_test_prec / len(test_dataloader)
    total_test_rec = total_test_rec / len(test_dataloader)

    return total_test_loss, total_test_acc, total_test_prec, total_test_rec, results_dict

In [86]:
# 3. Consolidated Training Loop

def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader, 
          test_dataloader: torch.utils.data.DataLoader, 
          epochs: int, 
          loss_fn: torch.nn.Module, 
          optimizer: torch.optim.Optimizer, 
          metrics: List[torchmetrics.Metric],
          device: torch.device):
    """
    This function combines the train_step() and test_step() functions we created above, to provide a consolidated training loop that includes both training & validation for a 
    specified number of epochs.

    Arguments:
        - model: a PyTorch model architecture that will be trained
        - train_dataloader: a PyTorch DataLoader that contains the training batches
        - test_dataloader: a PyTorch DataLoader that contains the testing batches
        - epochs: the number of epochs for which we will train the model (i.e., how many full iterations through the train & test DataLoaders)
        - loss_fn: a PyTorch loss function through which we measure how much the model errors
        - optimizer: a PyTorch optimizer that determines how we adjust the weights of the model
        - metrics: a list of PyTorch (torchmetrics) metrics that we want to review as the model progresses through training
        - device: the target device in which we will train the model (e.g., CPU, GPU)
    """
    consolidated_results_dict = {}

    # Get start time for model - want to check the time it takes to train, although SageMaker measures it as well (just for comparison)
    train_start_time = timer()

    for epoch in range(epochs):
        print(f"Epoch {epoch+1} begins")
        
        # Training loop
        (train_loss_epoch, train_acc_epoch, train_prec_epoch, train_rec_epoch,
         consolidated_results_dict[f'Epoch {epoch} training results']) = train_step(
            model=model,
            train_dataloader=train_dataloader,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metrics=metrics,
            device=device)

        # Testing loop
        (test_loss_epoch, test_acc_epoch, test_prec_epoch, test_rec_epoch,
         consolidated_results_dict[f'Epoch {epoch} testing results']) = test_step(
            model=model,
            test_dataloader=test_dataloader,
            loss_fn=loss_fn,
            metrics=metrics,
            device=device)
        
        # Print out some results
        print(f"Epoch {epoch+1} ends - here are the results:")
        print(f"Train Loss: {train_loss_epoch:.3f}\nTrain Acc: {train_acc_epoch*100:.2f}%\nTrain Precision: {train_prec_epoch:.3f}\nTrain Recall: {train_rec_epoch:.3f}")
        print(f"Test Loss: {test_loss_epoch:.3f}\nTest Acc: {test_acc_epoch*100:.2f}%\nTest Precision: {test_prec_epoch:.3f}\nTest Recall: {test_rec_epoch:.3f}")
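        print("-"*100)

    # Get end time for model
    train_end_time = timer()

    # timer() measures seconds, so divide by 60 to report minutes
    print(f"Time to train model: {(train_end_time - train_start_time)/60:.2f} minutes")

    # Return the per-batch results of every epoch, in case we want to plot the progress
    return consolidated_results_dict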

    

# 10. Set up the `main()` function that will start everything!

In [87]:
def main():
    print("Start Training")

    # 1. Setup device-agnostic code
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 2. Create argument parser that will provide the hyperparameters to the model
    parser = argparse.ArgumentParser()

    ## Argument for number of epochs
    parser.add_argument("--epochs", type=int, default=2)
    ## Argument for train batch size
    parser.add_argument("--train_batch_size", type=int, default=4) # we have already created the dataloaders, so won't be using this - could have added everything here in the main() function
    ## Argument for test batch size
    parser.add_argument("--test_batch_size", type=int, default=2) # we have already created the dataloaders, so won't be using this - could have added everything here in the main() function
    ## Argument for the learning rate of the optimizer
    parser.add_argument("--learning_rate", type=float, default=5e-05)

    args = parser.parse_args()

    # 3. Get tokenizer
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    # 4. Initialize a model instance
    model = FT_DistilBERT(num_classes=4)

    # 5. Choose loss function, optimizer & accuracy metrics
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=model.parameters(), 
                                 lr=args.learning_rate)
    accuracy_fn = torchmetrics.Accuracy(task='multiclass', num_classes=4)
    precision_fn = torchmetrics.Precision(task='multiclass', num_classes=4)
    recall_fn = torchmetrics.Recall(task='multiclass', num_classes=4)
    metrics_list = [accuracy_fn, precision_fn, recall_fn]

    # 6. Train model
    train(model=model,
          train_dataloader=train_dataloader,
          test_dataloader=test_dataloader,
          epochs=args.epochs,
          loss_fn=loss_fn,
          optimizer=optimizer,
          metrics=metrics_list,
          device=device)

    # 7. Specify output directory
    output_dir = os.environ['SM_MODEL_DIR']
    print(f"Output directory: {output_dir}")

    output_model_file = os.path.join(output_dir, 'pytorch_distilbert_model_news.bin')
    output_vocab_file = os.path.join(output_dir, 'vocab_distilbert_news.bin')

    # 8. Save model weights & vocabulary
    torch.save(obj=model.state_dict(),
               f=output_model_file)
    tokenizer.save_vocabulary(save_directory=output_vocab_file)

In [None]:
# To run the code:
#if __name__ == '__main__':
#    main()
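
One practical note on testing `main()` from a notebook: `parser.parse_args()` reads the command line of the running process, and inside Jupyter that command line contains the kernel's own arguments, which the parser does not know about. Also, `SM_MODEL_DIR` is an environment variable that SageMaker sets inside the training container, so it will usually be missing locally. Below is a minimal sketch of how both can be handled for local experiments. The function `parse_hyperparameters` and the fallback folder `./model_output` are assumptions for illustration, not part of `script.py`.

```python
import argparse
import os

def parse_hyperparameters(argv=None):
    """Parse the same hyperparameters as main(), tolerating unknown arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--train_batch_size", type=int, default=4)
    parser.add_argument("--test_batch_size", type=int, default=2)
    parser.add_argument("--learning_rate", type=float, default=5e-05)

    # parse_known_args() returns the unrecognized arguments separately instead of exiting with an error
    args, _unknown = parser.parse_known_args(argv)
    return args

# No arguments at all: we get the defaults
default_args = parse_hyperparameters([])

# Explicit arguments, plus a kernel-style argument that the parser does not know about
custom_args = parse_hyperparameters(["--epochs", "3", "--learning_rate", "1e-5", "-f", "kernel.json"])

# Hypothetical local fallback for when SM_MODEL_DIR is not set
output_dir = os.environ.get("SM_MODEL_DIR", "./model_output")

print(f"Default epochs: {default_args.epochs} | Default learning rate: {default_args.learning_rate}")
print(f"Custom epochs: {custom_args.epochs} | Custom learning rate: {custom_args.learning_rate}")
print(f"Output directory: {output_dir}")
```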