# Practicum 07 - Text Classification with Transformer Encoders

In this practicum, we will use the [Medical Text Dataset](https://www.kaggle.com/datasets/chaitanyakck/medical-text) (see also [here](https://github.com/sebischair/Medical-Abstracts-TC-Corpus)) to build a classification model that identifies which of five possible disease areas is discussed in an input medical abstract. The models will use only the abstract text as input. We will use [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) along with [torchtext](https://pytorch.org/text/stable/index.html) to build the models.

We will illustrate many of the concepts needed to work with text data, including tokenization and token embedding. As in the previous practicums, we will first demonstrate the techniques on a non-biomedical dataset and then apply them to the _Medical Text Dataset_.

In [None]:
# Google Colab setup
# mount the google drive - this is necessary to access supporting src
from google.colab import drive
drive.mount("/content/drive")

In [None]:
# install any packages not found in the Colab environment
!pip install lightning
!pip install 'portalocker>=2.0.0'

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
import torchtext
from torchtext.functional import to_tensor
import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch import seed_everything
import lightning.pytorch.trainer as trainer
import torchmetrics as TM
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report

In [None]:
dir_dataroot = "/content/drive/MyDrive/Colab Notebooks/CPSC-8810-ML-BioMed/data"

dir_lightning = "/content/drive/MyDrive/Colab Notebooks/CPSC-8810-ML-BioMed/lightning"

rs = 123456 # random seed for everything

# PyTorch Lightning Demonstration on AG_NEWS

We will begin by creating a PTL model to classify news articles in the [AG News](https://paperswithcode.com/dataset/ag-news) dataset, one of the many datasets available in the [torchtext](https://pytorch.org/text/stable/datasets.html) library. The _AG News_ dataset contains 120,000 training articles and 7,600 test articles divided into 4 classes. In this example, we will build a deep learning model with an encoder composed of a token embedding layer and 2 transformer layers, followed by a feed-forward layer to predict class membership for an input article.

First, we need a directory for the data and a directory to save trained models. By default, PTL saves versions of the model during training, called __checkpoints__, and rather than saving these in the current directory we keep them in a dedicated directory. In the repository layout these are `../../data` and `../../lightning`, both ignored by git; in this Colab version they are the Google Drive paths `dir_dataroot` and `dir_lightning` set in the cell above.

# Preprocessing Text Data
Text data is somewhat unique in that, unlike most other data types, it is inherently non-numerical. To facilitate its use in our machine learning models (or any computational model), we need to convert the input text to a numerical representation. We will do this by:
1. Tokenizing all samples in the dataset (known as a corpus in NLP)
2. Forming a fixed-size vocabulary composed of the unique tokens, with an integer index assigned to each token
3. Converting an input text to its token index representation
4. Formulating low dimensional numerical vectors (embeddings) for each token
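The four steps can be sketched in plain Python on a toy two-sentence corpus. This is a minimal illustration under stated assumptions, not torchtext's implementation: the lowercase-and-split `toy_tokenize` function, the `toy_vocab` dictionary (specials first, then tokens by descending frequency with alphabetical tie-breaking, similar to what `build_vocab_from_iterator` does), and the random `embedding_table` are hypothetical stand-ins for `get_tokenizer`, `build_vocab_from_iterator`, and a trainable `nn.Embedding` layer.

```python
import random
from collections import Counter

# Hypothetical stand-in for a tokenizer: lowercase and split on whitespace
def toy_tokenize(text):
    return text.lower().split()

corpus = ["The cat sat", "The dog sat down"]

# Steps 1-2: tokenize the corpus and build a vocabulary; special tokens come first
counts = Counter(tok for doc in corpus for tok in toy_tokenize(doc))
specials = ["<unk>", "<pad>"]
# remaining tokens ordered by descending frequency, ties broken alphabetically
ordered = sorted(counts, key=lambda t: (-counts[t], t))
toy_vocab = {tok: i for i, tok in enumerate(specials + ordered)}

# Step 3: map text to token indices; unseen tokens fall back to <unk>
def to_indices(text):
    return [toy_vocab.get(tok, toy_vocab["<unk>"]) for tok in toy_tokenize(text)]

# Step 4: an embedding is a lookup table with one vector per vocabulary index
# (random here; in the model these vectors are learned during training)
random.seed(0)
embedding_dim = 3
embedding_table = [[random.gauss(0, 1) for _ in range(embedding_dim)] for _ in toy_vocab]

indices = to_indices("the cat ran")
embedded = [embedding_table[i] for i in indices]
print(indices)                        # [3, 4, 0]  ("ran" is unknown -> <unk>)
print(len(embedded), len(embedded[0]))  # 3 3
```

The key design point carried over to torchtext is the `<unk>` default index: any token not seen while building the vocabulary still maps to a valid index, so new test text never breaks the lookup.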

Ultimately, we will package all of these steps into a PyTorch Lightning Data Module which will handle these processing steps for use during training and testing. But first, let's look at the steps in detail.

Let's start by viewing a sample from the _AG News_ dataset:

In [None]:
print(next(iter(AG_NEWS(split="train"))))

Next, we can tokenize the _AG News_ samples using the [torchtext tokenizer utility](https://pytorch.org/text/stable/data_utils.html). We can also create a simple Python generator to combine with the [torchtext build_vocab_from_iterator](https://pytorch.org/text/stable/vocab.html#build-vocab-from-iterator) function to build our _AG News_ vocabulary.

In [None]:
tokenizer = get_tokenizer("basic_english")
train_iter = AG_NEWS(split="train")

# simple generator to yield list of tokens for all samples
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

# build the vocabulary. We add special tokens <unk> and <pad> to handle unknown words for new test samples and padding for model input
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
vocab.set_default_index(vocab["<unk>"])
print("Tokens in vocabulary:", len(vocab))
# we see that the vocabulary is a dictionary like object with the token as key and the index as value
pad_idx = vocab['<pad>']
print(pad_idx)

Our data module will need to convert input text samples to their respective token index representations during the training, validation, and test steps. Importantly, the samples will be provided in batches, so our training, validation, and test data loaders (which are stored in the data module) must handle this batch conversion from text to token indices. This can be done by passing an appropriate function to each data loader through the `collate_fn` argument. Recall that in our Lightning modules we implement `training_step`, `validation_step`, and `test_step` methods that take a `batch` input, which is a tuple containing the input `x` and the corresponding label `y` (at least for classification problems). We would like our `collate_fn` to provide the `batch` input in this format.

To simplify training, we will require that all inputs have the same length. One way to do this is to process all samples in the dataset and find the longest one (i.e., the one with the most tokens). Any input shorter than this one is padded with `<pad>` tokens to the length of the longest sample. (The data module below also caps this length at a maximum we choose, truncating longer inputs.)
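The batching logic can be sketched in pure Python: each `(text, label)` pair is converted to indices, padded with the `<pad>` index (or truncated) to a fixed length, and the labels are shifted to start at 0, returning an `(x, y)` tuple. The toy vocabulary, `PAD_IDX`, `MAX_LEN`, and the `toy_collate` name are hypothetical; the notebook's data module does the same job with tensors, and the `(text, label)` order matches the dataset classes defined below.

```python
# Hypothetical toy vocabulary; index 0 is <unk>, index 1 is <pad>
toy_vocab = {"<unk>": 0, "<pad>": 1, "stocks": 2, "rise": 3, "team": 4, "wins": 5, "again": 6}
PAD_IDX = toy_vocab["<pad>"]
MAX_LEN = 4  # fixed input length for every sample in the batch

def text_to_indices(text):
    return [toy_vocab.get(tok, toy_vocab["<unk>"]) for tok in text.lower().split()]

def pad_or_truncate(indices, length, pad_idx):
    # cut long sequences, right-pad short ones with the padding index
    return indices[:length] + [pad_idx] * max(0, length - len(indices))

def toy_collate(batch):
    """batch: list of (text, label) pairs with labels starting at 1 (int or str)."""
    x = [pad_or_truncate(text_to_indices(text), MAX_LEN, PAD_IDX) for text, _ in batch]
    y = [int(label) - 1 for _, label in batch]
    return x, y

x, y = toy_collate([("Stocks rise", "3"), ("Team wins again and again", "2")])
print(x)  # [[2, 3, 1, 1], [4, 5, 6, 0]]
print(y)  # [2, 1]
```

Because a `DataLoader` calls `collate_fn` on the list of samples it draws for each batch, whatever tuple this function returns is exactly what `training_step` receives as `batch`.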

To simplify the process, it will be useful to have a function that converts a text input to its token index representation. Also, because the sample labels start at 1 (and may be provided as strings), we want to convert them to integers starting at 0. We can accomplish both tasks with `lambda` functions.

In [None]:
# function to convert text to index representation
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

Let's look at the resulting token representation for the first sample in the _AG News_ dataset:

In [None]:

label, text = next(iter(AG_NEWS(split="train")))
print(text_pipeline(text))
label_pipeline(label)

Next, let's find the sample with the most tokens in our input.

In [None]:
max_tokens = 0
cnt = 0
for label, text in AG_NEWS(split="train"):
    l = len(text_pipeline(text))
    if l > max_tokens:
        max_tokens = l
    cnt += 1
print("Longest training text sample:", max_tokens)
print("Number of training samples:", cnt)

## Let's create a reusable Data Module Class
Now that we have seen most of the pieces we need to prepare our text input, let's assemble the data module. We first need to define custom PyTorch `Dataset` classes to handle loading the samples. Our classes extend the `torch.utils.data.Dataset` class and must implement the `__len__()` method, which returns the number of samples in the dataset, and the `__getitem__()` method, which returns the sample at index `idx`. We develop two such classes:
1. `TextMapperDataset` - handles datasets provided by [torchtext.datasets](https://pytorch.org/text/stable/datasets.html) without requiring us to use the [torchdata project](https://pytorch.org/data/beta/index.html) which is still in Beta at the time of this writing.
2. `TextFileDataset` - handles text data stored in a .csv file where it is assumed that each row is of the form _label, sample_.

Finally, we implement the `TextDataModule` class, which extends the PyTorch Lightning `LightningDataModule` class and can be used with either text files or torchtext datasets.
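`TextFileDataset` reads the file line by line, splits on the first comma, and strips every double quote. That works when each record sits on one line, but a quoted field containing a newline would be split into two records, and escaped quotes (`""`) would be mangled. Python's standard `csv` module handles both. The class below is a hypothetical alternative (`CsvTextDataset` is not part of this notebook); it duck-types the map-style dataset protocol (`__len__`/`__getitem__`) so it runs without PyTorch, and it returns `(text, label)` pairs in the same order as the classes above. The in-memory CSV content is made up.

```python
import csv
import io

class CsvTextDataset:
    """Map-style dataset: implements __len__ and __getitem__ like torch.utils.data.Dataset."""
    def __init__(self, file_obj, skip_header=True):
        reader = csv.reader(file_obj)
        if skip_header:
            next(reader)
        # each row is (label, text); store as (text, label) to match TextMapperDataset
        self.samples = [(text, label) for label, text in reader]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

# Made-up stand-in for a train.csv file; the second record has a quoted comma,
# a quoted newline, and escaped double quotes
fake_csv = io.StringIO(
    "label,text\n"
    "1,Tumor growth was measured.\n"
    '4,"Blood pressure, heart rate\nand ""ECG"" were recorded."\n'
)
ds = CsvTextDataset(fake_csv)
print(len(ds))  # 2
print(ds[1])
```

With a real file, the `csv` documentation recommends opening it with `newline=""`, e.g. `with open(path, newline="") as f: ds = CsvTextDataset(f)`.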

In [None]:
class TextMapperDataset(torch.utils.data.Dataset):
    def __init__(self, data_pipe):
        super().__init__()
        self.data_pipe = data_pipe
        self.samples = self.load_to_memory()
        self.length = len(self.samples)

    def load_to_memory(self):
        samples = []
        for _label, _text in self.data_pipe:
            samples.append((_text, _label))
        return samples

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return self.samples[idx]

class TextFileDataset(TextMapperDataset):
    def __init__(self, file_path, skip_header=True):
        self.file_path = file_path
        self.samples = self.load_to_memory(skip_header)
        self.length = len(self.samples)

    def load_to_memory(self, skip_header):
        samples = []
        with open(self.file_path, 'r') as f:
            if skip_header:
                next(f)
            for line in f:
                label, text = line.split(',',maxsplit=1)
                samples.append((text.replace('\n', '').replace('"',''), label))
        return samples


class TextDataModule(L.LightningDataModule):
    def __init__(self, data_source, val_fraction = 0.1,
                 tokenizer="basic_english", batch_size=16, embedding_dim=64, max_input_length=512,
                 class_name_map=None):
        """
        data_source: str or callable, if str, it is a path to a directory with train.csv and test.csv files
        val_fraction: float, fraction of training data to use for validation
        tokenizer: str, tokenizer to use
        batch_size: int, batch size for training
        embedding_dim: int, dimension of word embeddings
        max_input_length: int, maximum number of tokens to use in input. The input sequence length is set to the
        minimum of this value and the longest tokenized training sample (longer texts are truncated, shorter ones padded)
        class_name_map: dict, mapping of class index to class name
        """
        super().__init__()
        self.data_source = data_source
        self.batch_size = batch_size
        self.val_fraction = val_fraction
        self.tokenizer = get_tokenizer(tokenizer)
        self.embedding_dim = embedding_dim
        self.max_tokens = None
        self.vocab = None
        self.max_input_length = max_input_length
        self.class_name_map = class_name_map

    def build_vocab(self, data_iterable):
        def yield_tokens(data_iter):
            for text, _ in data_iter:
                yield self.tokenizer(text)
        self.vocab = build_vocab_from_iterator(yield_tokens(data_iterable), specials=["<unk>", "<pad>"])
        self.vocab.set_default_index(self.vocab["<unk>"])
        self.padding_index = self.vocab['<pad>']
        return self.vocab

    def max_tokens_in(self, data_iterable):
        if self.vocab is None:
            self.build_vocab(data_iterable)
        text_to_tokens = lambda x: self.vocab(self.tokenizer(x))
        max_tokens = 0
        for text, _ in data_iterable:
            l = len(text_to_tokens(text))
            if l > max_tokens:
                max_tokens = l
        return max_tokens

    def max_sample_length(self):
        return min(self.max_input_length, self.max_tokens)

    def label_pipeline(self, x):
        return int(x) - 1

    def text_pipeline(self, x):
        return self.vocab(self.tokenizer(x))

    def collate_batch(self, batch):
        """
        Collate function to convert a batch of text samples to a tensor of input tokens and a tensor of labels
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        label_list = []
        max_input = min(self.max_tokens, self.max_input_length)
        text_batch = torch.zeros((len(batch), max_input), dtype=torch.int64)
        pad_idx = self.padding_index
        cnt = 0
        for _text, _label in batch:
            label_list.append(self.label_pipeline(_label))
            text = torch.tensor(self.text_pipeline(_text), dtype=torch.int64)
            if len(text) > max_input:
                text = text[:max_input]
            elif len(text) < max_input:
                text = torch.concat((text, torch.tensor([pad_idx] * (max_input - len(text)), dtype=torch.int64)))
            text_batch[cnt] = text
            cnt += 1
        label_list = torch.tensor(label_list, dtype=torch.int64)
        return text_batch.to(device), label_list.to(device)

    def setup(self, stage=None):
        if type(self.data_source)==str and os.path.exists(self.data_source): # treat data source as a directory, expect a train and test file
            dataset = TextFileDataset(os.path.join(self.data_source, 'train.csv'))
            self.max_tokens = self.max_tokens_in(dataset)
            test_dataset = TextFileDataset(os.path.join(self.data_source, 'test.csv'))
        else:
            dataset = TextMapperDataset(self.data_source(split='train'))
            self.max_tokens = self.max_tokens_in(dataset)
            test_dataset = TextMapperDataset(self.data_source(split='test'))
        self.build_vocab(dataset)
        n_data = len(dataset)
        n_train = int((1-self.val_fraction) * n_data)
        n_val = n_data - n_train

        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [n_train, n_val])
        # self.max_tokens = self.max_tokens_in(dataset)

        self._train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_batch)
        self._val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_batch)
        self._test_dataloader = DataLoader(test_dataset, batch_size=self.batch_size, collate_fn=self.collate_batch)

    def train_dataloader(self):
        return self._train_dataloader

    def test_dataloader(self):
        return self._test_dataloader

    def val_dataloader(self):
        return self._val_dataloader

Now that we have a data module class, let's create a `TextDataModule` instance for the _AG News_ data.

In [None]:
# load the AG_NEWS dataset
seed_everything(rs)
dm = TextDataModule(AG_NEWS, batch_size=16, max_input_length=256, class_name_map={0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'})
dm.setup()

We can examine the dataloaders contained in the `dm` data module to see how many batches are present and verify that it creates the expected token index representations for input samples.
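The batch counts follow from the split arithmetic in `setup` and from the `DataLoader` default `drop_last=False`, which keeps a final partial batch. The worked example below assumes the standard torchtext sizes for _AG News_ (120,000 training and 7,600 test articles); with `val_fraction=0.1` and `batch_size=16` it predicts the numbers the next cell should print. `split_and_batches` is a hypothetical helper.

```python
import math

def split_and_batches(n_data, n_test, val_fraction, batch_size):
    # mirror TextDataModule.setup: the training split is truncated to an int,
    # and the validation split gets the remainder
    n_train = int((1 - val_fraction) * n_data)
    n_val = n_data - n_train
    # a DataLoader with drop_last=False yields ceil(n / batch_size) batches
    batches = [math.ceil(n / batch_size) for n in (n_train, n_val, n_test)]
    return (n_train, n_val), batches

sizes, batches = split_and_batches(120_000, 7_600, 0.1, 16)
print(sizes)    # (108000, 12000)
print(batches)  # [6750, 750, 475]
```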

In [None]:
print(len(dm.train_dataloader()))
print(len(dm.val_dataloader()))
print(len(dm.test_dataloader()))

In [None]:
x=next(iter(dm.train_dataloader()))
print(x[0].shape, x[1].shape)
print(x[0][0])
print(x[1])

# Text Classification Model
Now we will build a text classification model. Similar to the approach we used for CNN classification of images, we will develop a model with an encoder component and a classifier component. In the encoder, we'll first use an _Embedding_ layer to transform the token index representation of an input into a low-dimensional vector for each token, and then use a _Transformer Encoder_ block to further refine these representations using attention mechanisms. The classifier component will flatten the Transformer output and pass it through a single feed-forward (linear) layer to generate the logits.

We will make these components modular by adding arguments to the `__init__` constructor so that we can reuse these classes for models with different parameters.
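The tensor shapes through the two components can be traced with a little arithmetic. The helper below is hypothetical bookkeeping, not part of the model. It assumes `batch_first=True` as in `TextEncoder`, and that the classifier flattens the `(seq_length, embedding_dim)` encoder output before one linear layer, which therefore has `seq_length * embedding_dim * num_classes` weights plus `num_classes` biases. For illustration we assume a sequence length of 256 (the `max_input_length` cap); the actual value is `dm.max_sample_length()`.

```python
def trace_shapes(batch_size, seq_length, embedding_dim, num_classes):
    shapes = {
        "token indices": (batch_size, seq_length),
        # nn.Embedding adds a feature dimension per token
        "embedding output": (batch_size, seq_length, embedding_dim),
        # the transformer encoder (batch_first=True) preserves the shape
        "encoder output": (batch_size, seq_length, embedding_dim),
        # nn.Flatten merges every dimension after the batch dimension
        "flattened": (batch_size, seq_length * embedding_dim),
        "logits": (batch_size, num_classes),
    }
    linear_params = seq_length * embedding_dim * num_classes + num_classes
    return shapes, linear_params

shapes, n_params = trace_shapes(batch_size=16, seq_length=256, embedding_dim=64, num_classes=4)
for name, shape in shapes.items():
    print(f"{name:>17}: {shape}")
print("linear layer parameters:", n_params)  # 65540
```

One consequence of the flatten-then-linear design: the classifier's weight matrix is tied to `seq_length`, so a model built for one data module's `max_sample_length()` cannot be applied directly to inputs of a different length.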

In [None]:
class TextEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_heads,
                 num_transformer_layers=2, dim_feedforward=1024, activation='relu', dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.te_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
                                              dim_feedforward=dim_feedforward, dropout=dropout,
                                              activation=activation, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer=self.te_layer, num_layers=num_transformer_layers)

    def forward(self, x):
        x = self.embedding(x)
        x = self.transformer_encoder(x)
        return x

# Test the TextEncoder with random input data
encoder = TextEncoder(len(dm.vocab), dm.embedding_dim, num_heads=4)
input_text = torch.randint(0, len(dm.vocab), (dm.batch_size, dm.max_sample_length()))
print("input shape",input_text.shape)
output_features = encoder(input_text)
print("Output shape:", output_features.shape)

In [None]:
class TextClassifier(nn.Module):
    def __init__(self, embedding_dim, seq_length, num_classes):
        super().__init__()
        self.flatten = nn.Flatten()  # merge the sequence and embedding dimensions
        self.linear = nn.Linear(seq_length*embedding_dim, num_classes)  # one output unit (logit) per class

    def forward(self, x):
        x = self.flatten(x)  # (batch, seq_length, embedding_dim) -> (batch, seq_length*embedding_dim)
        x = self.linear(x)   # -> (batch, num_classes) logits
        return x

# Test the module with random input data
classifier = TextClassifier(dm.embedding_dim, dm.max_sample_length(), len(dm.class_name_map))
input_tensor = torch.randn(dm.batch_size, dm.max_sample_length(), dm.embedding_dim)  # shape [batch_size, max_sample_length, embedding_dim]
print(input_tensor.shape)
output = classifier(input_tensor)
print("Output shape:", output.shape)  # expected: [batch_size, num_classes], i.e. [16, 4] for AG News

Now we will combine our encoder and classifier to form the overall model using the PyTorch Lightning `LightningModule` class. Although we've changed the class name to `TextClassifierModel`, the code is essentially the same as the image classifier we created in the previous practicum; the changes are contained in the encoder and classifier components.

In [None]:
class TextClassifierModel(L.LightningModule):
    def __init__(self, encoder, classifier, num_classes):
        super().__init__()
        # model layers
        self.encoder = encoder
        self.classifier = classifier

        # validation metrics
        self.val_metrics_tracker = TM.wrappers.MetricTracker(TM.MetricCollection([TM.classification.MulticlassAccuracy(num_classes=num_classes)]))
        self.validation_step_outputs = []
        self.validation_step_targets = []

        # test metrics
        self.test_roc = TM.ROC(task="multiclass", num_classes=num_classes, thresholds=list(np.linspace(0.0, 1.0, 20))) # roc and cm have methods we want to call so store them in a variable
        self.test_cm = TM.ConfusionMatrix(task='multiclass', num_classes=num_classes)
        self.test_metrics_tracker = TM.wrappers.MetricTracker(TM.MetricCollection([TM.classification.MulticlassAccuracy(num_classes=num_classes),
                                                            self.test_roc, self.test_cm]))
        self.test_step_outputs = []
        self.test_step_targets = []

    def forward(self, x):
        x = self.encoder(x)
        x = self.classifier(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log('val_loss', loss, on_step=True, on_epoch=True)

        # store the outputs and targets for the epoch end step
        self.validation_step_outputs.append(logits)
        self.validation_step_targets.append(y)
        return loss

    def on_validation_epoch_end(self):
        # Lightning runs a short validation "sanity check" before training; skip it so the
        # tracker records exactly one entry per training epoch
        if self.trainer.sanity_checking:
            self.validation_step_outputs.clear()
            self.validation_step_targets.clear()
            return

        # stack all the outputs and targets into a single tensor
        all_preds = torch.vstack(self.validation_step_outputs)
        all_targets = torch.hstack(self.validation_step_targets)

        # compute the metrics
        loss = nn.functional.cross_entropy(all_preds, all_targets)
        self.val_metrics_tracker.increment()
        self.val_metrics_tracker.update(all_preds, all_targets)
        self.log('val_loss_epoch_end', loss)

        # clear the validation step outputs
        self.validation_step_outputs.clear()
        self.validation_step_targets.clear()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log('test_loss', loss, on_step=True, on_epoch=True)
        self.test_step_outputs.append(logits)
        self.test_step_targets.append(y)
        return loss

    def on_test_epoch_end(self):
        all_preds = torch.vstack(self.test_step_outputs)
        all_targets = torch.hstack(self.test_step_targets)

        self.test_metrics_tracker.increment()
        self.test_metrics_tracker.update(all_preds, all_targets)
        # clear the test step outputs
        self.test_step_outputs.clear()
        self.test_step_targets.clear()

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

## Model Training
Now we are ready to train and evaluate our model. We follow the same procedure as the previous practicum, using the PyTorch Lightning `Trainer` class.

In [None]:
seed_everything(rs)
encoder = TextEncoder(len(dm.vocab), dm.embedding_dim, num_heads=4)
classifier = TextClassifier(dm.embedding_dim, dm.max_sample_length(), len(dm.class_name_map))
agnews_model = TextClassifierModel(encoder, classifier, num_classes=len(dm.class_name_map))

trainer = L.Trainer(default_root_dir=dir_lightning,
                    max_epochs=5,
                    callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model=agnews_model, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader())
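`EarlyStopping(monitor="val_loss", mode="min")` ends training once the monitored value has failed to improve for `patience` consecutive validation checks; Lightning's default `patience` is 3 (the medical model below sets 5). The function below is a simplified, hypothetical simulation of that rule on a made-up loss curve, ignoring refinements such as `min_delta` (default 0.0) and the non-finite-value check, to make the stopping logic concrete.

```python
def early_stop_epoch(losses, patience=3):
    """Return the 1-based epoch at which training would stop, or None if it runs to the end."""
    best = float("inf")
    wait = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:  # mode="min": only a strictly lower loss counts as improvement
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Made-up validation losses: best at epoch 3, then no further improvement
val_losses = [0.90, 0.70, 0.60, 0.65, 0.62, 0.61, 0.64]
print(early_stop_epoch(val_losses, patience=3))  # 6
print(early_stop_epoch(val_losses, patience=5))  # None
```

With `max_epochs=5`, as used for _AG News_, the epoch limit may be reached before the patience window is ever exhausted.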

## Validation Set Accuracy
Next, let's examine the validation set accuracy.

In [None]:
mca = agnews_model.val_metrics_tracker.compute_all()['MulticlassAccuracy']
plt.plot(range(1, len(mca)+1), mca, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Epoch Validation Accuracy')
plt.grid()

# Test Set Performance
Now that we've trained the model, we're ready to examine the test set accuracy. Here, we use the PyTorch Lightning `Trainer`'s `test` method, passing in the test dataloader from our data module.

In [None]:
trainer.test(model=agnews_model, dataloaders=dm.test_dataloader())

Next, we gather the performance metrics from the model's `test_metrics_tracker` attribute using the `compute` method.

In [None]:
rslt = agnews_model.test_metrics_tracker.compute()

Now, we can plot the confusion matrix and the class level ROC curves to assess performance.

In [None]:
cmp = sns.heatmap(rslt['MulticlassConfusionMatrix'], annot=True, fmt='d', cmap='Blues')
cmp.set_xlabel('Predicted Label')
cmp.set_xticklabels(dm.class_name_map.values(), rotation=90)
cmp.set_yticklabels(dm.class_name_map.values(), rotation=0)
cmp.set_ylabel('Actual Label');
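To read the heatmap, it helps to see how a multiclass confusion matrix is tallied: rows index the actual class and columns the predicted class, so the diagonal counts correct predictions (matching the `Actual Label`/`Predicted Label` axes set above). The pure-Python sketch below uses made-up labels. It also computes two accuracies, because, as far as we know, torchmetrics' `MulticlassAccuracy` defaults to `average="macro"` (the mean of per-class recalls), which can differ from the overall fraction correct when classes are imbalanced.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    # cm[i][j] counts samples whose actual class is i and predicted class is j
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm

# Made-up labels for three classes
y_true = [0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0, 2]
cm = confusion_matrix(y_true, y_pred, num_classes=3)
for row in cm:
    print(row)  # [1, 1, 0] / [0, 2, 0] / [1, 0, 2]

# Overall (micro) accuracy: diagonal total over all samples
micro_acc = sum(cm[i][i] for i in range(3)) / len(y_true)
# Macro accuracy: mean of per-class recall (diagonal entry over its row sum)
macro_acc = sum(cm[i][i] / sum(cm[i]) for i in range(3)) / 3
print(round(micro_acc, 4), round(macro_acc, 4))  # 0.7143 0.7222
```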

In [None]:
fpr, tpr, thresholds = rslt['MulticlassROC']
for i in range(len(dm.class_name_map)):
    plt.plot(fpr[i], tpr[i], label=dm.class_name_map[i])
plt.xlabel('False Positive Rate')
plt.plot([0, 1], [0, 1], 'k--', label='Random')
plt.ylabel('True Positive Rate')
plt.legend()
plt.grid()
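Each ROC curve treats one class as positive and all others as negative (one-vs-rest). As we understand torchmetrics' multiclass ROC, logits are converted to probabilities with a softmax, and each threshold in the list passed to `TM.ROC` yields one (FPR, TPR) point per class, with a sample counted as positive when its probability is at least the threshold. The sketch below reproduces that computation for one class on made-up logits; `softmax` and `roc_points` are hypothetical helpers.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def roc_points(probs, y_true, cls, thresholds):
    """One-vs-rest (FPR, TPR) for class `cls`; a sample is called positive if its prob >= threshold."""
    scores = [p[cls] for p in probs]
    positives = [t == cls for t in y_true]
    n_pos = sum(positives)
    n_neg = len(positives) - n_pos
    points = []
    for thr in thresholds:
        tp = sum(s >= thr and pos for s, pos in zip(scores, positives))
        fp = sum(s >= thr and not pos for s, pos in zip(scores, positives))
        points.append((fp / n_neg, tp / n_pos))
    return points

# Made-up logits for 4 samples and 3 classes
logits = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [1.2, 1.0, 0.1], [0.1, 0.2, 2.5]]
y_true = [0, 1, 1, 2]
probs = [softmax(row) for row in logits]
print(roc_points(probs, y_true, cls=0, thresholds=[0.0, 0.5, 1.0]))
# [(1.0, 1.0), (0.0, 1.0), (0.0, 0.0)]
```

A curve that hugs the top-left corner means the class's probability separates its samples from the rest at many thresholds; a curve near the dashed diagonal means the model ranks that class little better than chance.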

Finally, let's get the classification report.

In [None]:
# Print the classification report
device = torch.device("cpu")   #"cuda:0"
agnews_model.eval()
y_true=[]
y_pred=[]
with torch.no_grad():
    for test_data in dm.test_dataloader():
        test_samples, test_labels = test_data[0].to(device), test_data[1].to(device)
        pred = agnews_model(test_samples).argmax(dim=1)
        for i in range(len(pred)):
            y_true.append(test_labels[i].item())
            y_pred.append(pred[i].item())

print(classification_report(y_true,y_pred,target_names=list(dm.class_name_map.values()),digits=4))
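`classification_report` lists, per class, precision (the fraction of predictions of that class that are correct), recall (the fraction of that class's true samples that were found), F1 (their harmonic mean), and support (the number of true samples). The pure-Python sketch below applies those formulas to made-up labels; `per_class_report` is a hypothetical helper. It should agree with scikit-learn whenever a class has at least one prediction and one true sample; scikit-learn controls the zero-division cases through its `zero_division` argument, whereas this sketch simply returns 0.0.

```python
def per_class_report(y_true, y_pred, num_classes):
    report = {}
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        n_pred = sum(p == c for p in y_pred)   # column sum of the confusion matrix
        support = sum(t == c for t in y_true)  # row sum of the confusion matrix
        precision = tp / n_pred if n_pred else 0.0
        recall = tp / support if support else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[c] = {"precision": precision, "recall": recall, "f1": f1, "support": support}
    return report

# Made-up labels for three classes
y_true = [0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0, 2]
for c, stats in per_class_report(y_true, y_pred, 3).items():
    print(c, {k: round(v, 4) for k, v in stats.items()})
```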

# Medical Abstract Classification
Now let's try out our model on the [medical abstract dataset](https://www.kaggle.com/datasets/chaitanyakck/medical-text/data). This dataset is much smaller than the _AG News_ dataset (about 1/10 the size), so we should not expect the same level of performance.

# Problem 1 (2 points)
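In the code cell below, use the `TextDataModule` class to create a data module for the medical abstract dataset using the train.csv and test.csv files in `medical_dir` (the _medical-abstracts-tc-corpus_ directory under `dir_dataroot`), then call its `setup()` method as we did for _AG News_. Pass the `class_map` dictionary as the `class_name_map` argument, and in your inputs to the `TextDataModule` constructor, also set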
````
batch_size=16
max_input_length=512
embedding_dim=64
````

In [None]:
seed_everything(rs)
medical_dir = os.path.join(dir_dataroot,"medical-abstracts-tc-corpus")
class_map = {0: 'Neoplasms', 1: 'Digestive System', 2: 'Nervous System', 3: 'Cardiovascular', 4: 'Generic'}

########## YOUR CODE HERE ############
dm = None

########## YOUR CODE HERE ############

Let's see how many batches are in the training, validation, and test datasets.

In [None]:
print(len(dm.train_dataloader()))
print(len(dm.val_dataloader()))
print(len(dm.test_dataloader()))

# Problem 2 (2 points)
In the code cell below, create a `TextEncoder`, `TextClassifier`, and `TextClassifierModel`, and then train the model. In the encoder, set `num_heads=2`. The other model attributes should be set using the DataModule variable `dm`, in the same way they were for the _AG News_ model above.

In [None]:
seed_everything(rs)
########## START YOUR CODE HERE ############
encoder = None
classifier = None
medical_model = None
########## END YOUR CODE HERE ############

trainer = L.Trainer(default_root_dir=dir_lightning,
                    max_epochs=50,
                    callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=5)])
trainer.fit(model=medical_model, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader())

## Validation Set Accuracy
Now let's examine the validation set accuracy.

In [None]:
mca = medical_model.val_metrics_tracker.compute_all()['MulticlassAccuracy']
plt.plot(range(1, len(mca)+1), mca, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Epoch Validation Accuracy')
plt.grid()

## Test Set
Finally, let's examine the test set accuracy including the confusion matrix, ROC curves, and print the classification report.

In [None]:
trainer.test(model=medical_model, dataloaders=dm.test_dataloader())

In [None]:
rslt = medical_model.test_metrics_tracker.compute()

In [None]:
cmp = sns.heatmap(rslt['MulticlassConfusionMatrix'], annot=True, fmt='d', cmap='Blues')
cmp.set_xlabel('Predicted Label')
cmp.set_xticklabels(dm.class_name_map.values(), rotation=90)
cmp.set_yticklabels(dm.class_name_map.values(), rotation=0)
cmp.set_ylabel('Actual Label');

In [None]:
fpr, tpr, thresholds = rslt['MulticlassROC']
for i in range(len(dm.class_name_map)):
    plt.plot(fpr[i], tpr[i], label=dm.class_name_map[i])
plt.xlabel('False Positive Rate')
plt.plot([0, 1], [0, 1], 'k--', label='Random')
plt.ylabel('True Positive Rate')
plt.legend()
plt.grid()

In [None]:
# Print the classification report
device = torch.device("cpu")   #"cuda:0"
medical_model.eval()
y_true=[]
y_pred=[]
with torch.no_grad():
    for test_data in dm.test_dataloader():
        test_samples, test_labels = test_data[0].to(device), test_data[1].to(device)
        pred = medical_model(test_samples).argmax(dim=1)
        for i in range(len(pred)):
            y_true.append(test_labels[i].item())
            y_pred.append(pred[i].item())

print(classification_report(y_true,y_pred,target_names=list(dm.class_name_map.values()),digits=4))

# Problem 3 (2 points)
In the markdown cell below, provide your interpretation of the medical abstract classification model's performance. Specifically, why do you think the model performs poorly compared to the _AG News_ model? What do you notice in the confusion matrix and the ROC curves that supports your interpretation?