In [1]:
# %%bash
# git clone https://github.com/LiqunW/Long-document-dataset.git 
# cd Long-document-dataset/
# unrar x cs.AI.rar 
# unrar x cs.CE.rar
# unrar x cs.DS.rar
# unrar x cs.IT.rar
# unrar x cs.NE.rar
# unrar x cs.PL.rar
# unrar x cs.SY.rar
# unrar x cs.cv.rar
# unrar x math.AC.rar
# unrar x math.GR.rar
# unrar x math.ST.rar
# cd ..

In [1]:
!pip install git+https://github.com/allenai/longformer.git

Collecting git+https://github.com/allenai/longformer.git
  Cloning https://github.com/allenai/longformer.git to /tmp/pip-req-build-ykd6jyl6
  Running command git clone --filter=blob:none --quiet https://github.com/allenai/longformer.git /tmp/pip-req-build-ykd6jyl6
  Resolved https://github.com/allenai/longformer.git to commit caefee668e39cacdece7dd603a0bebf24df6d8ca
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting transformers@ git+http://github.com/ibeltagy/transformers.git@longformer_encoder_decoder#egg=transformers
  Cloning http://github.com/ibeltagy/transformers.git (to revision longformer_encoder_decoder) to /tmp/pip-install-lcimmogz/transformers_5955c852324242ddb2015aed26df31fe
  Running command git clone --filter=blob:none --quiet http://github.com/ibeltagy/transformers.git /tmp/pip-install-lcimmogz/transformers_5955c852324242ddb2015aed26df31fe
  Running command git checkout -b longformer_encoder_decoder --track origin/longformer_encoder_decoder
  Switched to a n

In [2]:
!pip install timm



In [1]:
from datasets import ArvixDataset
from model import SwinTransformer
from transformers import RobertaForMaskedLM, RobertaTokenizerFast,LongformerModel
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import os
import random
import logging
import requests
import tarfile
from transformers import RobertaConfig, RobertaModel
from tqdm import tqdm
import time
import numpy as np

In [2]:
logging.basicConfig(filename=os.path.join('./logs', 'swintransformer.log'),
                    format='%(asctime)s %(levelname)-8s %(message)s',
                    datefmt='%m-%d %H:%M', level=logging.INFO, filemode='w')

In [3]:
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# training parameters
model_config = {}
model_config["train_size"] = 2000
model_config["val_size"] = 200
model_config["test_size"] = 100

model_config['window_size'] = 128
model_config['batch_size'] = 2
model_config['max_len'] = 4096
model_config["weight_path"] = "./weight/"
model_config["num_epoch"] = 30
model_config["model_weight_path"] = None
# model_config["longformer_lr"] = 1e-6
# model_config["linear_lr"] = 1e-4
model_config["lr"] = 5e-6
model_config["gamma"] = 0.8
model_config["datapath"] = "./Long-document-dataset/"
device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
logger.info(model_config)

In [4]:
## dataloader
#download longformer-base-4096.tar.gz 
fname = 'longformer-base-4096.tar.gz'
url = 'https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/' + fname
r = requests.get(url)
open(fname, 'wb').write(r.content)

# !untar longformer-base-4096.tar.gz 
tar = tarfile.open(fname)
tar.extractall()
tar.close()

In [5]:
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base', model_max_length=model_config['max_len'])
train_dataset = ArvixDataset(model_config["datapath"], tokenizer, model_config, mode="train")
val_dataset = ArvixDataset(model_config["datapath"], tokenizer, model_config, mode="validation")
test_dataset = ArvixDataset(model_config["datapath"], tokenizer, model_config, mode="test")

In [6]:
train_dataloader = DataLoader(train_dataset, batch_size=model_config['batch_size'], shuffle=True, collate_fn=None)
val_dataloader = DataLoader(val_dataset, batch_size=model_config['batch_size'], shuffle=True, collate_fn=None)
data = next(iter(train_dataloader))

In [7]:
# test data
input_ids = data["Text"]
attention_mask = data["Attention"]
label = data["Label"]

In [8]:
input_ids, attention_mask, label

(tensor([[    0, 35746,   868,  ...,  8225,   528,     2],
         [    0,   250, 16681,  ..., 42736, 19220,     2]]),
 tensor([[2, 1, 1,  ..., 1, 1, 1],
         [2, 1, 1,  ..., 1, 1, 1]]),
 tensor([[0.],
         [0.]]))

In [9]:
pretrained_embed = LongformerModel.from_pretrained("allenai/longformer-base-4096")
init_weight=pretrained_embed.state_dict()
embed_weight = init_weight['embeddings.word_embeddings.weight'].to(device)

Some weights of LongformerModel were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['longformer.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
model = SwinTransformer(emb_init_weight=embed_weight, embed_dim=96, seq_length=4096)
model = model.to(device)

In [11]:
if model_config["model_weight_path"] is not None:
    file_name = os.path.join(model_config["weight_path"], model_config["model_weight_path"])
    model = torch.load(file_name).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr = model_config['lr'])
loss_fn = torch.nn.CrossEntropyLoss().to(device)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 5000, last_epoch=-1)

In [12]:
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [14]:
from time import time
for epoch in tqdm(range(model_config["num_epoch"])):
    start = time()
    logger.info("in epoch:" + str(round))
    total_train_loss = 0
    model.train()
    current_lr = scheduler.get_last_lr()
    print(f"Current Learning rate: {current_lr}")
    for step, data in enumerate(train_dataloader):
        start=time()
        input_ids = data["Text"].to(device)
        # attention_mask = data["Attention"].to(device)
        label = data["Label"].to(device)
        optimizer.zero_grad()  
        outputs = model(input_ids)
        loss = loss_fn(outputs, label.squeeze(1).long())
        total_train_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        end=time()
        
        # logger.info("Epoch: " + str(epoch)+ " Step: " + str(step)+ " Loss: " + str(loss.item()) + " Time: " + str(end-start))
        
        if(step % 10 == 0):
            logger.info("Epoch: " + str(epoch) + " Step: " + str(step)+ " Loss: " + str(loss.item()) + " Time: " + str(end-start))
            print(f"Loss after {step} step: {loss}")

    scheduler.step()
    
    avg_train_loss = total_train_loss / len(train_dataloader)   
    print("Average training loss: {0:.2f}".format(avg_train_loss))

    # save model weight
    print("Saving model weight...")

    if not os.path.exists(model_config['weight_path']):
        os.makedirs(model_config['weight_path'])

    weight_file_name = f"{model_config['weight_path']}/e{epoch}_model.pt"
    torch.save(model.state_dict(), weight_file_name)
        
    print("")
    print("Running Validation...")

    # Put the model in evaluation mode-
    model.eval()

    # Tracking variables 
    total_eval_accuracy = 0
    total_eval_loss = 0
    nb_eval_steps = 0

        # Evaluate data for one epoch
    for step, data in enumerate(val_dataloader):
        
        input_ids = data["Text"].to(device)
        # attention_mask = data["Attention"].to(device)
        label = data["Label"].to(device)

        with torch.no_grad():        
            outputs = model(input_ids)
            
        # Accumulate the validation loss.
        loss = loss_fn(outputs, label.squeeze(1).long())
        total_eval_loss += loss.item()

        # Move logits and labels to CPU
        logits = outputs.detach().cpu().numpy()
        label_ids = label.to('cpu').numpy()

        # Calculate the accuracy for this batch of test sentences, and
        # accumulate it over all batches.
        total_eval_accuracy += flat_accuracy(logits, label_ids)

    # Report the final accuracy for this validation run.
    avg_val_accuracy = total_eval_accuracy / len(val_dataloader)
    print("  Accuracy: {0:.2f}".format(avg_val_accuracy))

    # Calculate the average loss over all of the batches.
    avg_val_loss = total_eval_loss / len(val_dataloader)
    
    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    logger.info("Epoch: " + str(epoch) + "Accuracy: {0:.2f}".format(avg_val_accuracy))
    logger.info("Epoch: " + str(epoch) + "Validation Loss: {0:.2f}".format(avg_val_loss))

print("")
print("Training complete!")

  0%|          | 0/30 [00:00<?, ?it/s]

Current Learning rate: [5e-06]
Loss after 0 step: 2.3053860664367676
Loss after 10 step: 2.674060821533203
Loss after 20 step: 2.5934128761291504
Loss after 30 step: 2.4720776081085205
Loss after 40 step: 2.350025177001953
Loss after 50 step: 2.2207512855529785
Loss after 60 step: 1.995991826057434
Loss after 70 step: 2.880847930908203
Loss after 80 step: 2.418626308441162
Loss after 90 step: 2.7462940216064453
Loss after 100 step: 2.2188327312469482
Loss after 110 step: 2.3361899852752686
Loss after 120 step: 2.500506639480591
Loss after 130 step: 1.972010850906372
Loss after 140 step: 2.062450647354126
Loss after 150 step: 2.3182015419006348
Loss after 160 step: 2.230337142944336
Loss after 170 step: 2.29018235206604


  0%|          | 0/30 [00:20<?, ?it/s]


KeyboardInterrupt: 

In [11]:
import pandas as pd
import numpy as np

class Calculator:
    def __init__(self, num_class=11):
        self.num_class = num_class
        self.dictIdx2Cls = {
            0: "cs.AI",
            1: "cs.cv",
            2: "cs.IT",
            3: "cs.PL",
            4: "math.AC",
            5: "math.ST",
            6: "cs.CE", 
            7: "cs.DS",
            8: "cs.NE",
            9: "cs.SY", 
            10: "math.GR"
        }

    def init_metrics(self):
        class_list = [i for i in range(self.num_class)]
        val_list = [0] * self.num_class

        self.TP = dict(zip(class_list, val_list))
        self.positive_pred = dict(zip(class_list, val_list))
        self.positive_label = dict(zip(class_list, val_list))

        self.precision = dict(zip(class_list, val_list))
        self.recall = dict(zip(class_list, val_list))
        self.f1 = dict(zip(class_list, val_list))

    def update_result(self, preds, labels):
        preds_flat = np.argmax(preds, axis=1).flatten()
        labels_flat = labels.flatten()

        for i in range(self.num_class):

            this_pred = np.array([1 if pred == i else 0 for pred in preds_flat])
            this_label = np.array([1 if label == i else 0 for label in labels_flat])

            self.TP[i] += np.sum(this_pred * this_label)
            self.positive_pred[i] += np.sum(this_pred)
            self.positive_label[i] += np.sum(this_label)

    def get_overall_performance(self):

        precision = sum(self.TP.values()) / sum(self.positive_pred.values())
        recall = sum(self.TP.values()) / sum(self.positive_label.values())
        f1 = (2 * sum(np.array(list(result_calculator.precision.values())) * np.array(list(result_calculator.recall.values())))) / (sum(self.precision.values()) + sum(self.recall.values()))
        # accuracy = sum(self.correct.values()) / sum(self.total.values())
        total = sum(self.positive_label.values())

        return ["overall", total, precision, recall, f1]

    def get_metrics(self):

        for i in range(self.num_class):

            self.precision[i] = (self.TP[i] / self.positive_pred[i]) if self.positive_pred[i] else 0
            self.recall[i] = (self.TP[i] / self.positive_label[i]) if self.positive_label[i] else 0
            self.f1[i] = (2.0 * self.precision[i] * self.recall[i] / (self.precision[i] + self.recall[i])) if (self.precision[i] + self.recall[i]) else 0
            # self.accuracy[i] = self.correct[i] / self.total[i] if self.total[i] else 0
     
        result_dict = {
            "Class": self.dictIdx2Cls.values(),
            "Sample Size": self.positive_label.values(),
            # "Accuracy": self.accuracy.values(),
            "Precision": self.precision.values(),
            "Recall": self.recall.values(),
            "F1": self.f1.values()
        }

        result_df = pd.DataFrame(result_dict)
        result_df.loc[len(result_df.index)] = self.get_overall_performance()

        return result_df



In [12]:
# model = LongformerClassifier(config, pretrain=False, in_features=768, out_features=11).to(device)
model.load_state_dict(torch.load("weight/e18_model.pt"))
print(f"Load model weight from file")

Load model weight from file


In [13]:
test_dataset = ArvixDataset(model_config["datapath"], tokenizer, model_config, mode="test", max_len=model_config["max_len"])
test_dataloader = DataLoader(val_dataset, batch_size=model_config['batch_size'], shuffle=False, collate_fn=None)

In [14]:
total_test_accuracy = 0

result_calculator = Calculator(num_class=11)
result_calculator.init_metrics()

model.eval()

for step, data in enumerate(tqdm(test_dataloader)):
    start = time.time()
    input_ids = data["Text"].to(device)
    attention_mask = data["Attention"].to(device)
    label = data["Label"].to(device)

    with torch.no_grad():        
        outputs = model(input_ids)
        
    # Accumulate the validation loss.
    # loss = loss_fn(outputs, label.squeeze(1).long())
    # total_eval_loss += loss.item()

    # Move logits and labels to CPU
    logits = outputs.detach().cpu().numpy()
    label_ids = label.to('cpu').numpy()

    # Calculate the metrics for this batch of test sentences, and
    # accumulate it over all batches.
    result_calculator.update_result(logits, label_ids)

    # Calculate the accuracy for this batch of test sentences, and
    # accumulate it over all batches.
    total_test_accuracy += flat_accuracy(logits, label_ids)
    end = time.time()
    print(end-start)

# Report the final accuracy for this validation run.
avg_test_accuracy = total_test_accuracy / len(test_dataloader)
print("")
print("Test  Accuracy: {0:.3f}".format(avg_val_accuracy))

# Report the final metrics for this test run.
result_df = result_calculator.get_metrics()

  0%|          | 4/1100 [00:00<01:04, 17.07it/s]

0.025042295455932617
0.023756027221679688
0.023938894271850586
0.0238187313079834


  1%|          | 7/1100 [00:00<01:02, 17.47it/s]

0.023787260055541992
0.023787260055541992
0.023859739303588867
0.023773193359375


  1%|          | 11/1100 [00:00<01:02, 17.36it/s]

0.023787736892700195
0.023849964141845703
0.023890972137451172


  1%|          | 13/1100 [00:00<01:14, 14.58it/s]

0.02393341064453125
0.023925304412841797
0.023807048797607422


  2%|▏         | 17/1100 [00:01<01:14, 14.55it/s]

0.023995637893676758
0.023854732513427734
0.023772478103637695


  2%|▏         | 19/1100 [00:01<01:24, 12.84it/s]

0.023817777633666992
0.023734331130981445
0.023714303970336914


  2%|▏         | 23/1100 [00:01<01:23, 12.87it/s]

0.023799419403076172
0.023929357528686523
0.023777008056640625


  2%|▏         | 27/1100 [00:01<01:13, 14.54it/s]

0.02370619773864746
0.02368760108947754
0.023761272430419922
0.02368640899658203
0.02368783950805664


  3%|▎         | 31/1100 [00:02<01:21, 13.16it/s]

0.023688793182373047
0.02401900291442871
0.02382802963256836


  3%|▎         | 33/1100 [00:02<01:18, 13.60it/s]

0.023967742919921875
0.02379131317138672
0.02373981475830078


  3%|▎         | 37/1100 [00:02<01:23, 12.76it/s]

0.023705720901489258
0.02395462989807129
0.02374434471130371


  4%|▎         | 39/1100 [00:02<01:20, 13.15it/s]

0.023780107498168945
0.023854494094848633
0.0237424373626709


  4%|▍         | 43/1100 [00:03<01:12, 14.57it/s]

0.023663759231567383
0.02385997772216797
0.02364969253540039
0.02377605438232422


  4%|▍         | 44/1100 [00:03<01:15, 14.05it/s]


KeyboardInterrupt: 