In [1]:
import torch
import pandas as pd
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer

In [2]:
class Dataset:
    """Map-style dataset that tokenizes one text entry per lookup.

    Args:
        text: Indexable collection of raw texts; each entry is coerced
            with ``str()`` before tokenization.
        tokenizer: Callable invoked as
            ``tokenizer(text, max_length=..., padding="max_length", truncation=True)``
            whose result exposes ``"input_ids"`` and ``"attention_mask"``.
        max_len: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, text, tokenizer, max_len):
        self.text, self.tokenizer, self.max_len = text, tokenizer, max_len

    def __len__(self):
        """Number of text entries held by the dataset."""
        return len(self.text)

    def __getitem__(self, item):
        """Tokenize entry ``item`` and return its ids and mask as long tensors."""
        encoded = self.tokenizer(
            str(self.text[item]),
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
        )
        # Only these two fields are forwarded; anything else the tokenizer
        # returns is dropped.
        return {
            field: torch.tensor(encoded[field], dtype=torch.long)
            for field in ("input_ids", "attention_mask")
        }

In [3]:
def generate_predictions(
    model_path,
    max_len,
    csv_path="../input/jigsaw-toxic-severity-rating/comments_to_score.csv",
    batch_size=32,
    num_workers=4,
    device=None,
):
    """Score every row of a comments CSV with a sequence-classification model.

    Loads the model and tokenizer from ``model_path``, tokenizes the ``text``
    column of the CSV, runs batched inference without gradients, and collects
    the logit at column index 1 for each row.

    Args:
        model_path: Path or identifier accepted by ``from_pretrained`` for
            both the model and the tokenizer.
        max_len: Maximum token length forwarded to the tokenizer.
        csv_path: CSV file with a ``text`` column. Defaults to the competition
            file the notebook originally hard-coded.
        batch_size: Number of rows per inference batch.
        num_workers: Worker processes used by the ``DataLoader``.
        device: Device to run on (string or ``torch.device``). When ``None``,
            CUDA is used if available, otherwise the CPU.

    Returns:
        1-D ``numpy.ndarray`` of scores, one per CSV row, in file order.
    """
    # Fall back to the CPU instead of crashing when no GPU is present.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    use_cuda = device.type == "cuda"

    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    model.to(device)
    model.eval()

    df = pd.read_csv(csv_path)

    dataset = Dataset(text=df.text.values, tokenizer=tokenizer, max_len=max_len)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=use_cuda,
        # Order must be preserved so scores line up with CSV rows.
        shuffle=False,
    )

    final_output = []

    with torch.no_grad():
        for data in data_loader:
            batch = {key: value.to(device) for key, value in data.items()}
            logits = model(**batch).logits
            # NOTE(review): column 1 is taken as the score, so the model must
            # emit at least two logits per row; presumably index 1 is the
            # "toxic" class -- confirm against the model's label mapping.
            final_output.extend(logits[:, 1].cpu().tolist())

    if use_cuda:
        torch.cuda.empty_cache()
    return np.array(final_output)

In [4]:
# Run inference over the comments file with the model stored at the given path.
preds = generate_predictions(
    model_path="../input/autonlp-toxic-1/",
    max_len=192,
)

  cpuset_checked))


In [5]:
# Re-read the comments file, attach the predictions as a "score" column,
# and keep only the two columns written to the submission file.
scored = pd.read_csv(
    "../input/jigsaw-toxic-severity-rating/comments_to_score.csv"
).assign(score=preds)
sub = scored[["comment_id", "score"]]
sub.to_csv("submission.csv", index=False)

In [6]:
# Preview the first five rows of the submission frame.
sub.head(n=5)

Unnamed: 0,comment_id,score
0,114890,-3.361697
1,732895,-2.418945
2,1139051,-1.980471
3,1434512,-3.791222
4,2084821,0.711684
