Note by Bin:

AMP (automatic mixed precision training) is used to make the best use of the GPU memory. 

**It is not tested if the code works on CPU.**

In [1]:
import numpy as np
import pandas as pd
import os
import csv
import math
import torch
import random
from tqdm.auto import tqdm
from copy import deepcopy
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM

# Read the dataset

In [2]:
# Dataset from https://www.kaggle.com/datasets/czzzzzzz/scp1to7
scp_df = pd.read_csv('/kaggle/input/scp1to7/scp6999.csv', header=0, delimiter=',', quoting=csv.QUOTE_ALL, encoding='utf-8', index_col=False, usecols=['code', 'title', 'text', 'image captions', 'rating', 'state', 'tags',  'link'])

In [3]:
scp_df

Unnamed: 0,code,title,text,image captions,rating,state,tags,link
0,SCP-001,"""Awaiting De-classification [Blocked]""","""GENERAL NOTICE 001-Alpha: In order to prevent...",,,blocked,_cc _licensebox hub,https://scp-wiki.wikidot.com/scp-001
1,SCP-002,"""The ""Living"" Room""","""Item #: SCP-002 \n Object Class: Euclid \n Sp...","""SCP-002 in its containment area""",1702.0,active,_cc _licensebox alive euclid featured scp stru...,https://scp-wiki.wikidot.com/scp-002
2,SCP-003,"""Biological Motherboard""","""Item #: SCP-003 \n Object Class: Euclid \n Sp...","""A close up of SCP-003's circuitry""",765.0,active,_cc _licensebox alive biological computer dire...,https://scp-wiki.wikidot.com/scp-003
3,SCP-004,"""The 12 Rusty Keys and the Door""","""Item #: SCP-004 \n Object Class: Euclid \n Sp...","""SCP-004-1""",1096.0,active,_cc _licensebox euclid key mind-affecting port...,https://scp-wiki.wikidot.com/scp-004
4,SCP-005,"""Skeleton Key""","""Item #: SCP-005 \n Object Class: Safe \n Spec...","""A close up of SCP-005""",645.0,active,_cc _licensebox adaptive key metallic safe scp...,https://scp-wiki.wikidot.com/scp-005
...,...,...,...,...,...,...,...,...
6994,SCP-6995,"""Cannabincognito""","""Item #: SCP-6995 \n Object Class: Euclid \n S...",,54.0,active,6000 _licensebox alive antimemetic euclid mind...,https://scp-wiki.wikidot.com/scp-6995
6995,SCP-6996,"""Does the Red Moon Howl?""","""Check out my other pages on my author profile...","""SCP-6996. \n Image captured by surveillance e...",333.0,active,6000 _cc _licensebox concept empathic extradim...,https://scp-wiki.wikidot.com/scp-6996
6996,SCP-6997,"""De Rerum Natura""","""Item #: SCP-6997 \n Object Class: Safe \n Spe...","""The location of SCP-6997.""",96.0,active,6000 _cc _licensebox antimemetic engraved hall...,https://scp-wiki.wikidot.com/scp-6997
6997,SCP-6998,"""SCP Author Cerastes's Untitled SCP-6000 Conte...","""Item #: SCP-6998 \n Object Class: N/A \n Spec...",,131.0,active,6000 _listpages alive esoteric-class humanoid ...,https://scp-wiki.wikidot.com/scp-6998


In [4]:
unified_prompt = "This is an SCP-Foundation fiction:\nTitle:"

In [5]:
def set_seed(seed):
    # Seed function by Saurav Maheshkar
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

random_seed = 42
set_seed(random_seed)

Random seed set as 42


# Preprocessing

In [6]:
def remove_after(s, item_str, keep=True):
    if item_str in s:
        s = s.split(item_str)[0] + item_str if keep else s.split(item_str)[0]
    return s

def remove_before(s, item_str, keep=True):
    if item_str in s:
        s = item_str + s.split(item_str)[1] if keep else s.split(item_str)[1]
    return s

In [7]:
def preprocess(text, title):
    text = remove_before(text, "Item #:", keep=True)
    text = remove_after(text, "« SCP-", keep=False)
    text_train = unified_prompt + title[1:-1] + '\n' + text.strip("'").replace(' \n ', '\n')
    return text_train

In [8]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

Using Tesla P100-PCIE-16GB


In [9]:
print(preprocess(scp_df['text'][7], scp_df['title'][7]))

This is an SCP-Foundation fiction:
Title:Zombie Plague
Item #: SCP-008
Object Class: Euclid
Special Containment Procedures: SCP-008 samples have been deemed Class V extreme biological hazards, and all related protocols apply. Incineration and irradiation measures will be deployed in the event of political or military action which may result in the facility being dismantled; a power failure; or zero communications from operatives or outside channels during any given eight hour period.
The quarantine period for operatives leaving the facility is four months. If a breach has occurred, incineration and irradiation measures shall be deployed. It should be the policy of all G2 sites to not prepare an evacuation procedure.
Description: SCP-008 is a complex prion, samples of which are stored in each of the known G2 sites. Research into SCP-008 is highly classified and primarily aimed at preventing research which may lead to the synthesis of SCP-008 in the distant future. Traits of the SCP-008 

In [10]:
text_list = scp_df['text'].tolist()
title_list = scp_df['title'].tolist()
whole_list = [preprocess(text, title) for text, title in zip(text_list, title_list)]
random.shuffle(whole_list)
train_list = whole_list[:6000]
valid_list = whole_list[6000:]

In [11]:
class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=256):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        inputs = self.tokenizer("summarize: " + text, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        input_ids = inputs['input_ids'].squeeze()
        attention_mask = inputs['attention_mask'].squeeze()
        return {'input_ids': input_ids, 'attention_mask': attention_mask}

In [12]:
def calculate_perplexity(dataloader, model, device):
    model.eval()
    total_loss = 0
    total_count = 0

    for batch in tqdm(dataloader, desc="Computeing perplexity"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = input_ids.clone()
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_loss += loss.item() * input_ids.size(0)
            total_count += input_ids.size(0)

    perplexity = torch.exp(torch.tensor(total_loss / total_count)).item()
    return perplexity

# Training

In [13]:
model_name = "gpt2-medium"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Set padding token (GPT2 does not have a default one)
tokenizer.pad_token = tokenizer.eos_token

train_dataset = TextDataset(train_list, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_dataset = TextDataset(valid_list, tokenizer)
valid_dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
optimizer = AdamW(model.parameters(), lr=1e-4)
scheduler = StepLR(optimizer, step_size=1, gamma=0.1)

epochs = 2
lowest_perplexity = float('inf')
best_model = None

scaler = GradScaler()

model.train()

valid_perplexity = calculate_perplexity(valid_dataloader, model, device)
print(f"Validation perplexity before training: {valid_perplexity}")

for epoch in range(epochs):
    for batch in tqdm(train_dataloader, desc=f"Training epoch {epoch+1}"):
        optimizer.zero_grad()

        with autocast():
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = input_ids.clone()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    scheduler.step()

    scheduler.step()
    
    # model.eval() is done in the calculate_perplexity() method
    valid_perplexity = calculate_perplexity(valid_dataloader, model, device)
    print(f"Validation perplexity after epoch {epoch+1}: {valid_perplexity}")
    if valid_perplexity < lowest_perplexity:
        lowest_perplexity = valid_perplexity
        best_model = deepcopy(model)
        print("Best_model updated")

Downloading (…)lve/main/config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Computeing perplexity:   0%|          | 0/125 [00:00<?, ?it/s]

Validation perplexity before training: 22.89133644104004


Training epoch 1:   0%|          | 0/750 [00:00<?, ?it/s]

Computeing perplexity:   0%|          | 0/125 [00:00<?, ?it/s]

Validation perplexity after epoch 1: 8.773285865783691
Best_model updated


Training epoch 2:   0%|          | 0/750 [00:00<?, ?it/s]

Computeing perplexity:   0%|          | 0/125 [00:00<?, ?it/s]

Validation perplexity after epoch 2: 8.858723640441895


In [None]:
torch.save(best_model, "/kaggle/working/version4.pt")