In [1]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import os
import random
import logging

# Create the log directory first: basicConfig opens the log file immediately
# and raises FileNotFoundError if './logs' does not exist yet.
os.makedirs('./logs', exist_ok=True)
logging.basicConfig(filename=os.path.join('./logs', 'longformer.log'),
                    format='%(asctime)s %(levelname)-8s %(message)s',
                    datefmt='%m-%d %H:%M', level=logging.DEBUG, filemode='w')

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# training parameters
model_config = {}

# Documents per class in each split; ArvixDataset takes them in order from a
# seeded shuffle of every class folder (train, then validation, then test).
model_config["train_size"] = 500
model_config["val_size"] = 100
model_config["test_size"] = 100

model_config['window_size'] = 64       # Longformer sliding-window size per layer
model_config['batch_size'] = 2
model_config['max_len'] = 4096         # tokens per document after truncation/padding
model_config["datapath"] = "./Long-document-dataset"
model_config["weight_path"] = "./weight"
model_config["num_epoch"] = 5
model_config["model_weight_path"] = None   # checkpoint file name inside weight_path to resume from
model_config["longformer_lr"] = 1e-6
model_config["linear_lr"] = 1e-4
model_config["gamma"] = 0.8            # ExponentialLR decay factor
model_config["gpu_index"] = 1          # preferred CUDA device; falls back to cuda:0 if absent
model_config["seed"] = 1234

# Seed torch so DataLoader shuffling and the classifier head's initialisation are reproducible.
torch.manual_seed(model_config["seed"])

if torch.cuda.is_available():
    # Fall back to the first GPU when the preferred index does not exist on this machine.
    gpu_index = model_config["gpu_index"] if model_config["gpu_index"] < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device('cpu')

logger.info(model_config)

# create custom dataset class for arvix classification dataset
class ArvixDataset(Dataset):
    """Map-style dataset over the arXiv long-document classification corpus.

    ``path`` must contain one sub-folder per category (the keys of
    ``dictCls2Idx``); every regular file in a sub-folder is one document.
    In each folder, file names are sorted and then shuffled with a fixed seed
    (``SPLIT_SEED``). The files are sliced by position into contiguous
    train / validation / test ranges, whose sizes come from
    ``model_config["train_size"]``, ``["val_size"]`` and ``["test_size"]``.

    Each item is a dict:
      - "Text": int64 tensor of token ids, truncated/padded to ``max_len``
      - "Attention": int64 mask (0 = padding, 1 = local attention,
        2 = global attention, set on the first token)
      - "Label": 1-element float tensor holding the class index

    Raises:
        ValueError: if ``mode`` is not "train", "validation" or "test", or a
            class folder has fewer files than the requested split needs.
    """

    # Fixed seed so all three splits see the same per-folder permutation.
    SPLIT_SEED = 1234
    _MODES = ("train", "validation", "test")

    def __init__(self, path, tokenizer, model_config, mode='train', max_len=4096):

        # NOTE(review): arXiv's category id is "cs.CV"; the lowercase key must
        # match the on-disk folder name — confirm before changing it.
        self.dictCls2Idx = {
            "cs.AI": 0,
            "cs.cv": 1,
            "cs.IT": 2,
            "cs.PL": 3,
            "math.AC": 4,
            "math.ST": 5,
            "cs.CE": 6,
            "cs.DS": 7,
            "cs.NE": 8,
            "cs.SY": 9,
            "math.GR": 10
        }
        self.Idx2dictCls = {}
        self.dataset = []
        self.labels = []
        self.tokenizer = tokenizer
        self.max_len = max_len

        start, stop = self._split_bounds(model_config, mode)

        for sub, label_index in self.dictCls2Idx.items():
            subfolder = os.path.join(path, sub)
            self.Idx2dictCls[label_index] = sub

            files = sorted(f for f in os.listdir(subfolder) if os.path.isfile(os.path.join(subfolder, f)))
            # A private RNG yields the same permutation as random.seed(1234) followed by
            # random.shuffle(files), without resetting the process-wide `random` state.
            random.Random(self.SPLIT_SEED).shuffle(files)

            if len(files) < stop:
                raise ValueError(
                    f"{subfolder} has {len(files)} files but split {mode!r} needs {stop}")

            for f in files[start:stop]:
                self.dataset.append(os.path.join(subfolder, f))
                self.labels.append(label_index)

    @classmethod
    def _split_bounds(cls, model_config, mode):
        """Return the [start, stop) positional file range for ``mode`` within each folder."""
        if mode not in cls._MODES:
            raise ValueError(f"mode must be one of {cls._MODES}, got {mode!r}")
        train_end = model_config["train_size"]
        if mode == "train":
            return 0, train_end
        val_end = train_end + model_config["val_size"]
        if mode == "validation":
            return train_end, val_end
        return val_end, val_end + model_config["test_size"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Read, tokenize and return document ``idx`` as a dict of tensors."""
        label = self.labels[idx]
        data = self.read_txt(self.dataset[idx])
        encoded_data = self.tokenizer.encode(data, truncation=True, padding="max_length", max_length=self.max_len)
        input_ids = torch.tensor(encoded_data)

        # Longformer mask convention: 0 = no attention, 1 = local sliding-window
        # attention, 2 = global attention. Padding positions must be 0 so that
        # real tokens do not attend to pad tokens.
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            att_mask = torch.ones(len(encoded_data), dtype=torch.long)
        else:
            att_mask = (input_ids != pad_id).long()
        # Global attention on the first token; LongformerClassifier classifies from position 0.
        att_mask[0] = 2
        sample = {"Text": input_ids,
                  "Attention": att_mask,
                  "Label": torch.Tensor([label])}
        return sample

    def read_txt(self, file_path):
        """Return the UTF-8 file contents with every newline replaced by a space."""
        with open(file_path, 'r', encoding='utf-8') as file:
            # Replace rather than delete newlines so words on adjacent lines stay separated.
            text = file.read().replace('\n', ' ')
        return text

In [2]:
import torch
from longformer.longformer import Longformer, LongformerConfig
from longformer.sliding_chunks import pad_to_window_size
import requests
import tarfile
from tqdm import tqdm
import numpy as np
import torch
from transformers import RobertaForMaskedLM, RobertaTokenizerFast
import time

In [3]:
# Longformer encoder config: sliding-chunks attention, same window size in every layer.
config = LongformerConfig.from_pretrained('longformer-base-4096/')
config.attention_mode = 'sliding_chunks'
# One window entry per transformer layer; take the layer count from the config instead of hard-coding 12.
config.attention_window = [model_config['window_size']] * config.num_hidden_layers

In [4]:
class LongformerClassifier(torch.nn.Module):
    """Document classifier: a pretrained Longformer encoder plus a linear head on the first token.

    Args:
        in_features: width of the encoder's hidden states fed to the head (default 768).
        out_features: number of target classes (default 11).
        longformer_config: LongformerConfig for the encoder; when None, the
            notebook-level ``config`` is used (the original behaviour).
        pretrained_path: checkpoint directory/identifier passed to
            ``Longformer.from_pretrained``.
    """

    def __init__(self, in_features=768, out_features=11, longformer_config=None,
                 pretrained_path='longformer-base-4096/'):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        if longformer_config is None:
            longformer_config = config
        self.longformer = Longformer.from_pretrained(pretrained_path, config=longformer_config)
        self.linear = torch.nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, input_ids, attention_mask):
        """Return class logits computed from the hidden state at sequence position 0.

        Assumes the encoder's first output is (batch, seq, hidden) — TODO confirm
        against the longformer package.
        """
        hidden = self.longformer(input_ids=input_ids, attention_mask=attention_mask)[0]
        return self.linear(hidden[:, 0])

In [5]:
# Shared tokenizer; model_max_length agrees with the dataset's truncation/padding length.
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base', model_max_length=model_config["max_len"])

# Same folder and settings for every split; only `mode` selects the slice of each class folder.
train_dataset, val_dataset, test_dataset = (
    ArvixDataset(
        model_config["datapath"],
        tokenizer,
        model_config,
        mode=split_name,
        max_len=model_config["max_len"],
    )
    for split_name in ("train", "validation", "test")
)

In [6]:
train_dataloader = DataLoader(train_dataset, batch_size=model_config['batch_size'], shuffle=True, collate_fn=None)
val_dataloader = DataLoader(val_dataset, batch_size=model_config['batch_size'], shuffle=False, collate_fn=None)

# Sanity-check one training batch: fixed-length token ids and a mask of the same shape.
sample_batch = next(iter(train_dataloader))
assert sample_batch["Text"].shape[1] == model_config["max_len"]
assert sample_batch["Attention"].shape == sample_batch["Text"].shape

In [7]:
model = LongformerClassifier().to(device)

# Optionally resume. The training loop saves model.state_dict(), so restore it into
# the freshly built model; torch.load on such a file returns a dict, not a Module.
# map_location lets a checkpoint saved on another device load onto `device`.
if model_config["model_weight_path"] is not None:
    file_name = os.path.join(model_config["weight_path"], model_config["model_weight_path"])
    model.load_state_dict(torch.load(file_name, map_location=device))

# Two parameter groups: a small LR to fine-tune the pretrained encoder and a larger
# LR for the randomly initialised linear head.
optimizer = torch.optim.AdamW([
    {'params': model.longformer.parameters(), 'lr': model_config["longformer_lr"]},
    {'params': model.linear.parameters(), 'lr': model_config["linear_lr"]}])

loss_fn = torch.nn.CrossEntropyLoss().to(device)

# Each scheduler.step() multiplies every group's LR by gamma.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=model_config["gamma"], last_epoch=-1)

In [8]:
def flat_accuracy(preds, labels):
    """Fraction of samples whose highest-scoring class equals the label.

    Args:
        preds: array of per-class scores, one row per sample (argmax over axis 1).
        labels: array of class ids; flattened before comparison.

    Returns:
        Accuracy in [0, 1] as a NumPy float.
    """
    matches = np.argmax(preds, axis=1).ravel() == labels.ravel()
    return matches.mean()

In [9]:
# Fine-tune for num_epoch epochs. After each epoch: decay the learning rates,
# save a state_dict checkpoint to weight_path, and report validation metrics.
os.makedirs(model_config['weight_path'], exist_ok=True)

for epoch in tqdm(range(model_config["num_epoch"])):
    logger.info("in epoch:" + str(epoch))
    total_train_loss = 0
    model.train()
    # One learning rate per optimizer param group: [longformer, linear].
    current_lr = scheduler.get_last_lr()
    print(f"Current Learning rate for longformer: {current_lr[0]}, for linear layer: {current_lr[1]}")
    for step, data in enumerate(train_dataloader):
        start = time.time()
        input_ids = data["Text"].to(device)
        attention_mask = data["Attention"].to(device)
        label = data["Label"].to(device)
        optimizer.zero_grad()

        outputs = model(input_ids, attention_mask=attention_mask)

        # Labels arrive as (batch, 1) float tensors; CrossEntropyLoss wants (batch,) int64 class ids.
        loss = loss_fn(outputs, label.squeeze(1).long())
        total_train_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        end = time.time()
        logger.info("Epoch: " + str(epoch) + " Step: " + str(step) + " Loss: " + str(loss.item()) + " Time: " + str(end - start))

        if step % 10 == 0:
            print(f"Loss after {step} step: {loss.item()}")

    # Decay both parameter groups' learning rates once per epoch.
    scheduler.step()

    avg_train_loss = total_train_loss / len(train_dataloader)
    print("Average training loss: {0:.2f}".format(avg_train_loss))

    print("Saving model weight...")
    weight_file_name = os.path.join(model_config['weight_path'], f"e{epoch}_model.pt")
    torch.save(model.state_dict(), weight_file_name)

    print("")
    print("Running Validation...")

    model.eval()

    total_eval_accuracy = 0
    total_eval_loss = 0
    n_eval_samples = 0

    for step, data in enumerate(val_dataloader):
        input_ids = data["Text"].to(device)
        attention_mask = data["Attention"].to(device)
        label = data["Label"].to(device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            loss = loss_fn(outputs, label.squeeze(1).long())
        total_eval_loss += loss.item()

        logits = outputs.detach().cpu().numpy()
        label_ids = label.to('cpu').numpy()

        # Weight each batch by its size so a short final batch is not over-counted.
        batch_n = len(label_ids)
        total_eval_accuracy += flat_accuracy(logits, label_ids) * batch_n
        n_eval_samples += batch_n

    avg_val_accuracy = total_eval_accuracy / n_eval_samples if n_eval_samples else 0.0
    print("  Accuracy: {0:.2f}".format(avg_val_accuracy))

    # Validation loss is averaged over batches.
    avg_val_loss = total_eval_loss / len(val_dataloader)
    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    logger.info("Epoch: " + str(epoch) + " Accuracy: {0:.2f}".format(avg_val_accuracy))
    logger.info("Epoch: " + str(epoch) + " Validation Loss: {0:.2f}".format(avg_val_loss))

print("")
print("Training complete!")

  0%|          | 0/5 [00:00<?, ?it/s]

Current Learning rate for longformer: 1e-06, for linear layer: 0.0001
Loss after 0 step: 2.3883461952209473
Loss after 10 step: 2.5824050903320312
Loss after 20 step: 2.7779407501220703


  0%|          | 0/5 [00:46<?, ?it/s]


KeyboardInterrupt: 

In [None]:
import pandas as pd
import numpy as np

class Calculator:
    """Accumulate per-class prediction counts and report precision / recall / F1.

    Intended for single-label multi-class classification. Call
    ``update_result`` once per batch, then ``get_metrics`` to get a DataFrame
    with one row per class plus a micro-averaged "overall" row.
    ``init_metrics`` resets every counter; the constructor calls it, so a fresh
    instance is ready to use.
    """

    def __init__(self, num_class=11):
        self.num_class = num_class
        # Display names for class indices; mirrors ArvixDataset.dictCls2Idx.
        self.dictIdx2Cls = {
            0: "cs.AI",
            1: "cs.cv",
            2: "cs.IT",
            3: "cs.PL",
            4: "math.AC",
            5: "math.ST",
            6: "cs.CE",
            7: "cs.DS",
            8: "cs.NE",
            9: "cs.SY",
            10: "math.GR"
        }
        self.init_metrics()

    def init_metrics(self):
        """Reset true-positive / predicted / label counts and the derived metrics to zero."""
        classes = range(self.num_class)
        self.TP = {i: 0 for i in classes}
        self.positive_pred = {i: 0 for i in classes}
        self.positive_label = {i: 0 for i in classes}

        self.precision = {i: 0 for i in classes}
        self.recall = {i: 0 for i in classes}
        self.f1 = {i: 0 for i in classes}

    def update_result(self, preds, labels):
        """Add one batch to the running counts.

        Args:
            preds: array of per-class scores, one row per sample; the argmax over
                axis 1 is the predicted class.
            labels: array of true class ids (flattened before use).
        """
        preds_flat = np.argmax(preds, axis=1).flatten()
        labels_flat = np.asarray(labels).flatten()

        for i in range(self.num_class):
            is_pred = preds_flat == i
            is_label = labels_flat == i
            self.TP[i] += int(np.sum(is_pred & is_label))
            self.positive_pred[i] += int(np.sum(is_pred))
            self.positive_label[i] += int(np.sum(is_label))

    def get_overall_performance(self):
        """Return the micro-averaged row ``["overall", total, precision, recall, f1]``.

        Precision and recall pool the counts of all classes; F1 is their harmonic
        mean. Zero denominators yield 0.0.
        """
        tp = sum(self.TP.values())
        n_pred = sum(self.positive_pred.values())
        total = sum(self.positive_label.values())

        precision = tp / n_pred if n_pred else 0.0
        recall = tp / total if total else 0.0
        f1 = 2.0 * precision * recall / (precision + recall) if (precision + recall) else 0.0

        return ["overall", total, precision, recall, f1]

    def get_metrics(self):
        """Compute per-class precision / recall / F1 from the accumulated counts.

        Returns:
            DataFrame with columns Class, Sample Size, Precision, Recall, F1: one
            row per class, then the "overall" micro-averaged row. Ratios with a
            zero denominator are reported as 0.0.
        """
        for i in range(self.num_class):
            tp = self.TP[i]
            p = tp / self.positive_pred[i] if self.positive_pred[i] else 0.0
            r = tp / self.positive_label[i] if self.positive_label[i] else 0.0
            self.precision[i] = p
            self.recall[i] = r
            self.f1[i] = 2.0 * p * r / (p + r) if (p + r) else 0.0

        classes = range(self.num_class)
        result_df = pd.DataFrame({
            # Fall back to the index as the name for classes beyond the known 11.
            "Class": [self.dictIdx2Cls.get(i, str(i)) for i in classes],
            "Sample Size": [self.positive_label[i] for i in classes],
            "Precision": [self.precision[i] for i in classes],
            "Recall": [self.recall[i] for i in classes],
            "F1": [self.f1[i] for i in classes],
        })
        result_df.loc[len(result_df.index)] = self.get_overall_performance()

        return result_df

In [None]:
# Load the checkpoint to evaluate. map_location lets weights saved on another GPU load onto `device`.
# NOTE(review): this directory/epoch differs from model_config["weight_path"] and num_epoch=5
# used by the training cell above — confirm this is the intended checkpoint.
eval_weight_file = os.path.join("no_pretrain_weight", "e8_model.pt")
model.load_state_dict(torch.load(eval_weight_file, map_location=device))
print(f"Load model weight from file {eval_weight_file}")

In [None]:
test_dataset = ArvixDataset(model_config["datapath"], tokenizer, model_config, mode="test", max_len=model_config["max_len"])
# Build the loader from the held-out test split so test metrics are not computed on validation data.
test_dataloader = DataLoader(test_dataset, batch_size=model_config['batch_size'], shuffle=False, collate_fn=None)

In [None]:
# Evaluate on the test split: accumulate per-class counts and sample-weighted accuracy.
total_test_accuracy = 0
n_test_samples = 0

result_calculator = Calculator(num_class=11)
result_calculator.init_metrics()

model.eval()

for step, data in enumerate(tqdm(test_dataloader)):
    start = time.time()
    input_ids = data["Text"].to(device)
    attention_mask = data["Attention"].to(device)
    label = data["Label"].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # Move logits and labels to CPU for NumPy-based metrics.
    logits = outputs.detach().cpu().numpy()
    label_ids = label.to('cpu').numpy()

    result_calculator.update_result(logits, label_ids)

    # Weight each batch by its size so a short final batch is not over-counted.
    batch_n = len(label_ids)
    total_test_accuracy += flat_accuracy(logits, label_ids) * batch_n
    n_test_samples += batch_n
    # Per-batch timing goes to the log file rather than flooding the cell output.
    logger.debug("Test step: " + str(step) + " Time: " + str(time.time() - start))

avg_test_accuracy = total_test_accuracy / n_test_samples if n_test_samples else 0.0
print("")
print("Test  Accuracy: {0:.3f}".format(avg_test_accuracy))

# Per-class and overall metrics for this test run (displayed as the cell's output).
result_df = result_calculator.get_metrics()
result_df