In [1]:
import pandas as pd
import numpy as np
import pickle
import matplotlib.pyplot as plt
import re
import copy
from tqdm.notebook import tqdm
import time
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import accuracy_score, f1_score

from transformers import T5Tokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup

from sklearn.metrics import confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import numpy as np

train_df = pd.read_csv("kaggle_toxicity_data_cleaned.csv")
train_df = train_df.drop(columns = ["any_toxic"])
train_df['text_cleaned'].fillna('', inplace=True)
train_df.head()

Unnamed: 0,id,comments,toxic,severe_toxic,obscene,threat,insult,identity_hate,text_cleaned
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0,explanation edits made username hardcore metal...
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0,daww match background colour im seemingly stuc...
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0,hey man im really trying edit war guy constant...
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0,cant make real suggestion improvement wondered...
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0,sir hero chance remember page thats


In [3]:
def get_texts(df):
    texts = 'multilabel classification: ' + df['text_cleaned']
    texts = texts.values.tolist()
    return texts

def get_labels(df):
    # Get the label column and convert it to multi-class integer labels
    labels = df[['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']].values
    # Combine each row label into a unique integer
    integer_labels = [int("".join(map(str, row)), 2) for row in labels]
    return integer_labels

texts = get_texts(train_df)
labels = get_labels(train_df)

for text, label in zip(texts[:5], labels[:5]):
    print(f'TEXT -\t{text}')
    print(f'LABEL -\t{label}')
    print()

TEXT -	multilabel classification: explanation edits made username hardcore metallica fan reverted werent vandalism closure gas voted new york doll fac please dont remove template talk page since im retired 
LABEL -	0

TEXT -	multilabel classification: daww match background colour im seemingly stuck thanks talk  january   utc
LABEL -	0

TEXT -	multilabel classification: hey man im really trying edit war guy constantly removing relevant information talking edits instead talk page seems care formatting actual info
LABEL -	0

TEXT -	multilabel classification: cant make real suggestion improvement wondered section statistic later subsection type accident think reference may need tidying exact format ie date format etc later noone else first preference formatting style reference want please let know appears backlog article review guess may delay reviewer turn listed relevant form eg wikipediagoodarticlenominationstransport
LABEL -	0

TEXT -	multilabel classification: sir hero chance remember

In [4]:
class Config:
    def __init__(self):
        self.SEED = 42
        self.MODEL_PATH = 't5-small'
        self.TOKENIZER = T5Tokenizer.from_pretrained(self.MODEL_PATH)
        self.SRC_MAX_LENGTH = 250
        self.TGT_MAX_LENGTH = 10
        self.BATCH_SIZE = 32  # Reduced batch size to save time and memory
        self.VALIDATION_SPLIT = 0.25  # Reduced validation split for faster training

        self.DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.FULL_FINETUNING = True
        self.LR = 3e-5
        self.OPTIMIZER = 'AdamW'
        self.CRITERION = 'CrossEntropyLoss'
        self.SAVE_BEST_ONLY = True
        self.N_VALIDATE_DUR_TRAIN = 3
        self.EPOCHS = 3

config = Config()



In [5]:
class T5Dataset(Dataset):
    def __init__(self, df, indices):
        df = df.iloc[indices]
        self.texts = get_texts(df)
        self.labels = get_labels(df)
        self.tokenizer = config.TOKENIZER
        self.src_max_length = config.SRC_MAX_LENGTH

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        src_tokenized = self.tokenizer.encode_plus(
            self.texts[index],
            max_length=self.src_max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        
        label = torch.tensor(self.labels[index], dtype=torch.long)

        return {
            'src_input_ids': src_tokenized['input_ids'].squeeze().long(),
            'src_attention_mask': src_tokenized['attention_mask'].squeeze().long(),
            'labels': label
        }


np.random.seed(config.SEED)
dataset_size = len(train_df)
indices = list(range(dataset_size))
split = int(np.floor(config.VALIDATION_SPLIT * dataset_size))
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
train_data = T5Dataset(train_df, train_indices)
val_data = T5Dataset(train_df, val_indices)
train_dataloader = DataLoader(train_data, batch_size=config.BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=config.BATCH_SIZE)

In [6]:
class T5Model(nn.Module):
    def __init__(self):
        super(T5Model, self).__init__()
        self.t5_model = T5ForConditionalGeneration.from_pretrained(config.MODEL_PATH)
        # Replace the last layer with a linear layer and output 6 categories
        self.classifier = nn.Linear(self.t5_model.config.d_model, 64)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.t5_model.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Take the output of the encoder as the input of the classifier
        encoder_output = outputs.last_hidden_state.mean(dim=1)  
        logits = self.classifier(encoder_output)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, 64), labels.view(-1))
            return loss, logits
        return logits


In [7]:
def train(model, train_dataloader, optimizer, scheduler, epoch):
    model.train()
    start_time = time.time()
    train_loss = 0
    all_labels = []
    all_preds = []

    for step, batch in enumerate(tqdm(train_dataloader, desc='Training')):
        
        batch_start_time = time.time()
        
        input_ids = batch['src_input_ids'].to(config.DEVICE)
        attention_mask = batch['src_attention_mask'].to(config.DEVICE)
        labels = batch['labels'].to(config.DEVICE)

        optimizer.zero_grad()
        loss, logits = model(input_ids, attention_mask=attention_mask, labels=labels)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()
        scheduler.step()

        # Collect predictions for accuracy calculation
        preds = torch.argmax(logits, dim=-1).cpu().numpy()
        labels = labels.cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels)

        # 打印step日志，包括step时间
        print(f'Step {step+1}/{len(train_dataloader)} | Loss: {loss.item()} | Step Time: {time.time() - batch_start_time:.2f}s')

    avg_train_loss = train_loss / len(train_dataloader)
    train_accuracy = accuracy_score(all_labels, all_preds)  # Calculate training accuracy
    print(f'Epoch {epoch+1} | Training Loss: {avg_train_loss} | Training Accuracy: {train_accuracy} | Epoch Time: {time.time() - start_time:.2f}s')
    return avg_train_loss, train_accuracy


In [8]:
def validate(model, val_dataloader):
    model.eval()
    val_loss = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for batch in val_dataloader:
            input_ids = batch['src_input_ids'].to(config.DEVICE)
            attention_mask = batch['src_attention_mask'].to(config.DEVICE)
            labels = batch['labels'].to(config.DEVICE)

            loss, logits = model(input_ids, attention_mask=attention_mask, labels=labels)
            val_loss += loss.item()

            preds = torch.argmax(logits, dim=-1).cpu().numpy()
            labels = labels.cpu().numpy()
            
            all_preds.extend(preds)
            all_labels.extend(labels)

    avg_val_loss = val_loss / len(val_dataloader)
    val_accuracy = accuracy_score(all_labels, all_preds)
    return avg_val_loss, val_accuracy, all_labels, all_preds


In [None]:
import time

def run():
    model = T5Model().to(config.DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=config.LR)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * config.EPOCHS
    )

    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    epoch_times = []  

    for epoch in range(config.EPOCHS):
        
        epoch_start_time = time.time()
        
        # Training
        avg_train_loss, train_accuracy = train(model, train_dataloader, optimizer, scheduler, epoch)
        train_losses.append(avg_train_loss)
        train_accuracies.append(train_accuracy)
        
        # Validation
        avg_val_loss, val_accuracy, all_labels, all_preds = validate(model, val_dataloader)
        val_losses.append(avg_val_loss)
        val_accuracies.append(val_accuracy)
        
        
        epoch_time = time.time() - epoch_start_time
        epoch_times.append(epoch_time)

        print(f"Epoch {epoch+1} | Training Time: {epoch_time:.2f}s | Validation Loss: {avg_val_loss:.4f} | Validation Accuracy: {val_accuracy:.4f}")

    # Plot and Save Training and Validation Loss Curve
    plt.figure(figsize=(8, 6))
    plt.plot(range(1, config.EPOCHS + 1), train_losses, label='Training Loss', marker='o')
    plt.plot(range(1, config.EPOCHS + 1), val_losses, label='Validation Loss', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss Curve')
    plt.legend()
    plt.grid(True)
    plt.savefig('training_validation_loss_curve.png', format='png')
    plt.show()

    # Plot and Save Training and Validation Accuracy Curve
    plt.figure(figsize=(8, 6))
    plt.plot(range(1, config.EPOCHS + 1), train_accuracies, label='Training Accuracy', marker='o')
    plt.plot(range(1, config.EPOCHS + 1), val_accuracies, label='Validation Accuracy', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training and Validation Accuracy Curve')
    plt.legend()
    plt.grid(True)
    plt.savefig('training_validation_accuracy_curve.png', format='png')
    plt.show()

    # Plot and Save Training Time Curve
    plt.figure(figsize=(8, 6))
    plt.plot(range(1, config.EPOCHS + 1), epoch_times, label='Training Time (s)', marker='o', color='purple')
    plt.xlabel('Epoch')
    plt.ylabel('Time (seconds)')
    plt.title('Training Time per Epoch')
    plt.legend()
    plt.grid(True)
    plt.savefig('training_time_per_epoch.png', format='png')
    plt.show()

    # Process and Save Confusion Matrix
    filtered_labels = [label for label in all_labels if label != -100]
    filtered_preds = [pred for label, pred in zip(all_labels, all_preds) if label != -100]

    conf_matrix = confusion_matrix(filtered_labels, filtered_preds)

    # Plot confusion matrix with labels
    plt.figure(figsize=(8, 6))
    plt.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    tick_marks = np.arange(len(set(filtered_labels)))
    plt.xticks(tick_marks, tick_marks, rotation=45)
    plt.yticks(tick_marks, tick_marks)
    plt.xlabel('Predicted label')
    plt.ylabel('True label')

    # Add text labels to each cell in the confusion matrix
    thresh = conf_matrix.max() / 2
    for i in range(conf_matrix.shape[0]):
        for j in range(conf_matrix.shape[1]):
            plt.text(j, i, format(conf_matrix[i, j], 'd'),
                     ha="center", va="center",
                     color="white" if conf_matrix[i, j] > thresh else "black")

    plt.savefig('confusion_matrix.png', format='png')
    plt.show()

    # Save the trained model
    model_save_path = 't5_update_model.pth'
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    return model

best_model = run()


  return torch.load(checkpoint_file, map_location="cpu")


Training:   0%|          | 0/3740 [00:00<?, ?it/s]

Step 1/3740 | Loss: 4.134326457977295 | Step Time: 0.82s
Step 2/3740 | Loss: 4.129802227020264 | Step Time: 0.17s
Step 3/3740 | Loss: 4.098855495452881 | Step Time: 0.18s
Step 4/3740 | Loss: 4.095735549926758 | Step Time: 0.18s
Step 5/3740 | Loss: 4.0767292976379395 | Step Time: 0.17s
Step 6/3740 | Loss: 4.0529704093933105 | Step Time: 0.17s
Step 7/3740 | Loss: 4.049806118011475 | Step Time: 0.18s
Step 8/3740 | Loss: 4.061184406280518 | Step Time: 0.19s
Step 9/3740 | Loss: 4.033051013946533 | Step Time: 0.19s
Step 10/3740 | Loss: 4.025968074798584 | Step Time: 0.18s
Step 11/3740 | Loss: 4.014590740203857 | Step Time: 0.17s
Step 12/3740 | Loss: 3.9837303161621094 | Step Time: 0.17s
Step 13/3740 | Loss: 3.970710039138794 | Step Time: 0.17s
Step 14/3740 | Loss: 3.9543960094451904 | Step Time: 0.17s
Step 15/3740 | Loss: 3.9655356407165527 | Step Time: 0.17s
Step 16/3740 | Loss: 3.928330659866333 | Step Time: 0.17s
Step 17/3740 | Loss: 3.938983678817749 | Step Time: 0.17s
Step 18/3740 | Los

Training:   0%|          | 0/3740 [00:00<?, ?it/s]

Step 1/3740 | Loss: 0.3289826214313507 | Step Time: 0.17s
Step 2/3740 | Loss: 0.34666189551353455 | Step Time: 0.17s
Step 3/3740 | Loss: 0.6056380867958069 | Step Time: 0.17s
Step 4/3740 | Loss: 0.37166208028793335 | Step Time: 0.17s
Step 5/3740 | Loss: 0.08816030621528625 | Step Time: 0.17s
Step 6/3740 | Loss: 0.15795432031154633 | Step Time: 0.17s
Step 7/3740 | Loss: 0.4004672169685364 | Step Time: 0.17s
Step 8/3740 | Loss: 0.5138581991195679 | Step Time: 0.18s
Step 9/3740 | Loss: 0.2316731959581375 | Step Time: 0.17s
Step 10/3740 | Loss: 0.14502671360969543 | Step Time: 0.18s
Step 11/3740 | Loss: 0.7564706802368164 | Step Time: 0.17s
Step 12/3740 | Loss: 0.1959824562072754 | Step Time: 0.17s
Step 13/3740 | Loss: 0.337970107793808 | Step Time: 0.17s
Step 14/3740 | Loss: 0.17801663279533386 | Step Time: 0.17s
Step 15/3740 | Loss: 0.5636593699455261 | Step Time: 0.17s
Step 16/3740 | Loss: 0.10248098522424698 | Step Time: 0.17s
Step 17/3740 | Loss: 0.4554894268512726 | Step Time: 0.17s


Training:   0%|          | 0/3740 [00:00<?, ?it/s]

Step 1/3740 | Loss: 0.48475128412246704 | Step Time: 0.17s
Step 2/3740 | Loss: 0.14144445955753326 | Step Time: 0.17s
Step 3/3740 | Loss: 0.5906928777694702 | Step Time: 0.17s
Step 4/3740 | Loss: 0.44994688034057617 | Step Time: 0.17s
Step 5/3740 | Loss: 0.03282606601715088 | Step Time: 0.17s
Step 6/3740 | Loss: 0.5561456680297852 | Step Time: 0.17s
Step 7/3740 | Loss: 0.207791268825531 | Step Time: 0.17s
Step 8/3740 | Loss: 0.2850065231323242 | Step Time: 0.17s
Step 9/3740 | Loss: 0.10519100725650787 | Step Time: 0.17s
Step 10/3740 | Loss: 0.6358451247215271 | Step Time: 0.17s
Step 11/3740 | Loss: 0.49111708998680115 | Step Time: 0.17s
Step 12/3740 | Loss: 0.3451673090457916 | Step Time: 0.17s
Step 13/3740 | Loss: 0.5622207522392273 | Step Time: 0.19s
Step 14/3740 | Loss: 0.3241123557090759 | Step Time: 0.20s
Step 15/3740 | Loss: 0.16868720948696136 | Step Time: 0.20s
Step 16/3740 | Loss: 0.21726994216442108 | Step Time: 0.18s
Step 17/3740 | Loss: 0.2339804321527481 | Step Time: 0.18s