In [1]:
import torch
import transformers
import datasets
import sklearn

In [2]:
# Load GPT-2 tokenizer and model
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2', device_map='auto')
# Add a padding token to GPT-2 tokenizer (since it doesn't have one by default)
tokenizer.pad_token = tokenizer.eos_token

In [3]:
# Load a small subset of the IMDb dataset for binary sentiment classification
dataset = datasets.load_dataset('imdb', split='train[:50%]')
dataset = dataset.train_test_split(test_size=0.2)

In [4]:
# Annotate with whether a text contains the word 'plot'
def set_word_label(example):
    if 'plot' in example['text'].lower():
        example['label'] = 1
    else:
        example['label'] = 0
    return example

dataset = dataset.map(set_word_label)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

In [5]:
# Tokenization function
def tokenize_function(examples):
    return tokenizer(
        examples['text'], padding='max_length', truncation=True, max_length=128
    )

# Apply the tokenization
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Remove unnecessary columns and set format for PyTorch
tokenized_datasets = tokenized_datasets.remove_columns(['text'])
tokenized_datasets.set_format('torch')

# Split into train and evaluation datasets
train_dataset = tokenized_datasets['train']
eval_dataset = tokenized_datasets['test']

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

In [6]:
# Define a custom model with GPT-2 as feature extractor and a linear classifier on top
class GPT2ForClassification(torch.nn.Module):
    def __init__(self, gpt2, num_labels):
        super(GPT2ForClassification, self).__init__()
        self.gpt2 = gpt2
        # Freeze GPT-2 parameters
        for param in self.gpt2.parameters():
            param.requires_grad = False
        # Linear classifier
        self.classifier = torch.nn.Linear(self.gpt2.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        # Get hidden states from GPT-2
        outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state of the last token for classification
        last_token_indices = attention_mask.sum(dim=1) - 1
        pooled_output = outputs.last_hidden_state[torch.arange(input_ids.size(0)), last_token_indices]
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            # Compute loss
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.classifier.out_features), labels.view(-1))
        return {'loss': loss, 'logits': logits}

In [7]:
# Load GPT-2
gpt2 = transformers.GPT2Model.from_pretrained('gpt2', device_map='auto', torch_dtype='auto')

# Initialize the model
num_labels = 2  # Binary classification
model = GPT2ForClassification(gpt2, num_labels)

In [8]:
# Define training arguments to train for only 1 epoch
training_args = transformers.TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy="epoch",
    logging_dir='./logs',
    logging_steps=10,
)

In [9]:
# Function to compute evaluation metrics
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = torch.argmax(torch.from_numpy(logits), dim=-1).numpy()
    accuracy = sklearn.metrics.accuracy_score(labels, predictions)
    return {'accuracy': accuracy}

In [10]:
# Initialize the Trainer
trainer = transformers.Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

In [11]:
trainer.train()

  0%|          | 0/625 [00:00<?, ?it/s]

{'loss': 0.9479, 'grad_norm': 159.72096252441406, 'learning_rate': 4.92e-05, 'epoch': 0.02}
{'loss': 0.76, 'grad_norm': 28.952970504760742, 'learning_rate': 4.8400000000000004e-05, 'epoch': 0.03}
{'loss': 0.6734, 'grad_norm': 72.84471893310547, 'learning_rate': 4.76e-05, 'epoch': 0.05}
{'loss': 0.6677, 'grad_norm': 12.976938247680664, 'learning_rate': 4.6800000000000006e-05, 'epoch': 0.06}
{'loss': 0.6209, 'grad_norm': 31.5085391998291, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.08}
{'loss': 0.6092, 'grad_norm': 27.92189598083496, 'learning_rate': 4.52e-05, 'epoch': 0.1}
{'loss': 0.6358, 'grad_norm': 11.485227584838867, 'learning_rate': 4.44e-05, 'epoch': 0.11}
{'loss': 0.7262, 'grad_norm': 82.11431884765625, 'learning_rate': 4.36e-05, 'epoch': 0.13}
{'loss': 0.626, 'grad_norm': 68.08614349365234, 'learning_rate': 4.2800000000000004e-05, 'epoch': 0.14}
{'loss': 0.6629, 'grad_norm': 31.249441146850586, 'learning_rate': 4.2e-05, 'epoch': 0.16}
{'loss': 0.6413, 'grad_norm': 27.914

  0%|          | 0/157 [00:00<?, ?it/s]

{'eval_loss': 0.6224797964096069, 'eval_accuracy': 0.7024, 'eval_runtime': 12.3521, 'eval_samples_per_second': 202.395, 'eval_steps_per_second': 12.71, 'epoch': 1.0}
{'train_runtime': 62.5166, 'train_samples_per_second': 159.958, 'train_steps_per_second': 9.997, 'train_loss': 0.6528895858764648, 'epoch': 1.0}


TrainOutput(global_step=625, training_loss=0.6528895858764648, metrics={'train_runtime': 62.5166, 'train_samples_per_second': 159.958, 'train_steps_per_second': 9.997, 'total_flos': 0.0, 'train_loss': 0.6528895858764648, 'epoch': 1.0})

In [12]:
trainer.evaluate()

  0%|          | 0/157 [00:00<?, ?it/s]

{'eval_loss': 0.6224797964096069,
 'eval_accuracy': 0.7024,
 'eval_runtime': 12.2255,
 'eval_samples_per_second': 204.491,
 'eval_steps_per_second': 12.842,
 'epoch': 1.0}

In [13]:
# Use the trainer to make predictions on the evaluation dataset
predictions_output = trainer.predict(eval_dataset)
probabilities = torch.nn.functional.softmax(torch.tensor(predictions_output.predictions), dim=-1)

  0%|          | 0/157 [00:00<?, ?it/s]

In [14]:
predictions_output.predictions.argmax(axis=-1).nonzero()

(array([   1,    4,   20,   64,   76,   79,   88,   90,   97,  111,  115,
         121,  141,  183,  198,  209,  218,  219,  275,  287,  294,  308,
         336,  337,  373,  375,  380,  389,  394,  402,  408,  420,  423,
         442,  456,  496,  515,  523,  545,  555,  564,  576,  593,  612,
         618,  632,  646,  654,  666,  670,  711,  735,  736,  771,  797,
         806,  814,  839,  852,  855,  874,  937,  938,  951,  957,  959,
        1003, 1015, 1029, 1039, 1053, 1061, 1085, 1097, 1113, 1122, 1136,
        1144, 1159, 1167, 1176, 1188, 1196, 1207, 1222, 1250, 1253, 1296,
        1305, 1323, 1350, 1406, 1443, 1446, 1500, 1541, 1594, 1601, 1605,
        1640, 1651, 1652, 1696, 1704, 1714, 1729, 1758, 1784, 1824, 1835,
        1840, 1849, 1859, 1871, 1887, 1922, 1923, 1944, 1952, 1965, 1993,
        1997, 2032, 2052, 2115, 2125, 2145, 2161, 2163, 2207, 2221, 2237,
        2253, 2283, 2297, 2317, 2338, 2341, 2347, 2377, 2404, 2418, 2431,
        2437, 2439, 2458, 2462, 2471, 

In [15]:
predictions_output.label_ids.nonzero()

(array([   0,    6,    9,   13,   16,   24,   26,   36,   37,   40,   41,
          49,   55,   59,   61,   62,   72,   73,   79,   82,   87,   89,
          92,   94,   96,  104,  106,  107,  108,  112,  116,  117,  120,
         131,  132,  134,  142,  143,  149,  154,  155,  164,  168,  169,
         171,  173,  186,  187,  188,  189,  199,  201,  204,  205,  211,
         219,  222,  224,  226,  230,  231,  232,  234,  239,  240,  246,
         247,  257,  258,  264,  274,  275,  277,  278,  281,  282,  289,
         290,  291,  297,  298,  301,  303,  306,  307,  310,  311,  315,
         317,  323,  324,  325,  329,  332,  336,  337,  338,  341,  350,
         352,  354,  356,  357,  361,  366,  367,  369,  370,  371,  377,
         380,  381,  382,  383,  392,  402,  403,  408,  410,  419,  420,
         422,  429,  442,  443,  445,  451,  452,  453,  457,  458,  460,
         461,  465,  467,  470,  475,  485,  487,  488,  493,  507,  509,
         518,  520,  523,  528,  529, 

In [19]:
predictions_output.predictions.argmax(axis=-1).sum()

np.int64(149)

In [18]:
torch.tensor(dataset['test']['label']).sum()

tensor(673)

In [21]:
top_idxs = reversed(probabilities[:, 1].sort().indices)[:10]
for i in top_idxs:
    print(f"Predicted probability: {probabilities[i,1]}")
    print(f"True label: {predictions_output.label_ids[i]}")
    print(f"Text: {dataset['test'][i.item()]['text']}")
    print("-" * 80)

Predicted probability: 0.7796695232391357
True label: 0
Text: The American Humane Association, which is the source of the familiar disclaimer "No animals were harmed..." (the registered trademark of the AHA), began to monitor the use of animals in film production more than 60 years ago, after a blindfolded horse was forced to leap to its death from the top of a cliff for a shot in the film Jesse James (1939). Needless to say, the atrocious act kills the whole entertainment aspect of this film for me. I suppose one could say that at least the horse didn't die in vain, since it was the beginning of the public waking up to the callous and horrendous pain caused animals for the glory of movie making, but I can't help but feel that if the poor animal had a choice, this sure wouldn't have been the path he would have taken!
--------------------------------------------------------------------------------
Predicted probability: 0.7515990734100342
True label: 0
Text: The story concerns a genealo