In [1]:
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer


import pandas as pd
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
import os
from tqdm import tqdm
from sklearn.metrics import f1_score, roc_auc_score, roc_curve
from sklearn.metrics import classification_report

from transformers import PreTrainedTokenizerFast, XLNetTokenizerFast

from tokenizers import Tokenizer, normalizers, pre_tokenizers
from tokenizers.models import WordLevel
from tokenizers.normalizers import NFD, Lowercase, StripAccents
from tokenizers.pre_tokenizers import Digits, Whitespace, Punctuation
from tokenizers.trainers import WordLevelTrainer
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/jonibekmansurov/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
def cosine(u, v):
    return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))

In [3]:
sbert_model = SentenceTransformer('bert-base-nli-mean-tokens')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
constitution = pd.read_csv('data/external/constitution_of_india.csv')
constitution

Unnamed: 0,Articles
0,"1. Name and territory of the Union\n(1) India,..."
1,1. The territories of the States; the Union te...
2,2. Admission or establishment of new States: P...
3,2A. Sikkim to be associated with the Union Rep...
4,3. Formation of new States and alteration of a...
...,...
451,378A. Special provision as to duration of Andh...
452,392. Power of the President to remove difficul...
453,393. Short title This Constitution may be call...
454,"394. Commencement This article and Articles 5,..."


In [5]:
data = pd.read_csv('data/final_data/processed_data.csv')
data.fillna('[NOArt]', inplace=True)
data

Unnamed: 0,processed_text,constitution_article,crpc_article,ipc_article,label,split
0,F NARIMAN J Leave granted In 2008 the Punjab ...,~~Equality before law The State shall not deny...,[NOArt],[NOArt],1,train
1,S THAKUR J Leave granted These appeals are di...,~~Power of Parliament to amend the Constitutio...,[NOArt],[NOArt],0,train
2,Markandey Katju J Leave granted Heard learned...,[NOArt],~~Order for maintenance of wives children and ...,[NOArt],1,train
3,ALTAMAS KABIRJ Leave granted The question whe...,~~Power of High Courts to issue certain writs1...,~~Compounding of offences The offences punisha...,~~Whoever commits murder shall be punished wit...,1,train
4,CIVIL APPEAL NO 598 OF 2007 K MATHUR J This a...,~~Jurisdiction of existing High Courts Subject...,[NOArt],[NOArt],1,train
...,...,...,...,...,...,...
6071,civil appellate jurisdiction civil appeal numb...,[NOArt],[NOArt],[NOArt],1,dev
6072,criminal appellate jurisdiction special leave\...,[NOArt],[NOArt],[NOArt],0,dev
6073,civil appellate jurisdiction civil appeal numb...,[NOArt],[NOArt],[NOArt],0,dev
6074,civil appellate jurisdiction civil appeal numb...,[NOArt],[NOArt],[NOArt],1,dev


In [6]:
# X_train_dev, X_test, y_train_dev, y_test = train_test_split(data.drop(columns=['label']),
#                                                     data['label'], test_size=0.2, random_state=10)
#
# X_train, X_dev, y_train, y_dev = train_test_split(X_train_dev,
#                                                   y_train_dev, test_size=0.1, random_state=10)

In [7]:
train = data.loc[data.split == 'train']
dev = data.loc[data.split == 'dev']

X_train = train.drop(columns=['label'])
X_dev = dev.drop(columns=['label'])
y_train = train.label
y_dev = dev.label

In [8]:
X_train = X_train.reset_index(drop=True)
X_dev = X_dev.reset_index(drop=True)
# X_test = X_test.reset_index(drop=True)
y_train = y_train.reset_index(drop=True)
y_dev = y_dev.reset_index(drop=True)
# y_test = y_test.reset_index(drop=True)

In [9]:
# train a tokenizer, initialize WordLevel tokenizer
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
# we first define a normalizer applied before tokenization
tokenizer.normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()])
# pre-tokenizer defines a "preprocessing" before the tokenization.
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), Punctuation(),
                                                   Digits(individual_digits=True)])
# training a tokenizer is effectively building a vocabulary in this case
trainer = WordLevelTrainer(vocab_size=50000, special_tokens=["[PAD]", "[UNK]"])
tokenizer.train_from_iterator(X_train.processed_text.values, trainer=trainer)
tokenizer.save("tokenizer.json")

#load a tokenizer
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="tokenizer.json",
    unk_token="[UNK]",
    pad_token="[PAD]"
)


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


seed = 0
torch.manual_seed(seed)


class LegEval(Dataset):
    def __init__(self, legal_docs, constitutions, CrPcs, IPCs, labels, tokenizer, max_token_len=512):
        self.legal_docs = legal_docs
        self.constitutions = constitutions
        self.CrPcs = CrPcs
        self.IPCs = IPCs

        self.labels = labels
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len

    def __len__(self):
        return len(self.legal_docs)

    def __getitem__(self, idx: int):
        legal_doc = self.legal_docs[idx]

        constitution = self.constitutions[idx]
        CrPc = self.CrPcs[idx]
        IPC = self.IPCs[idx]
        label = self.labels[idx]

        encoding = self.tokenizer.encode(
            legal_doc,
            padding='max_length',
            max_length=self.max_token_len,
            truncation=True,
            return_tensors='pt'
        )

        constitution_encoding = self.tokenizer.encode(
            constitution,
            padding='max_length',
            max_length=self.max_token_len,
            truncation=True,
            return_tensors='pt'
        )

        CrPc_encoding = self.tokenizer.encode(
            CrPc,
            padding='max_length',
            max_length=self.max_token_len,
            truncation=True,
            return_tensors='pt'
        )

        IPC_encoding = self.tokenizer.encode(
            IPC,
            padding='max_length',
            max_length=self.max_token_len,
            truncation=True,
            return_tensors='pt'
        )

        return dict(
            input_ids=encoding,
            input_constitution=constitution_encoding,
            input_CrPc=CrPc_encoding,
            input_IPC=IPC_encoding,
            label=torch.tensor([label], dtype=torch.float),
        )

In [11]:
train_dataset = LegEval(
    X_train.processed_text, X_train.constitution_article,
    X_train.crpc_article, X_train.ipc_article,
    y_train, tokenizer
)

# test_dataset = LegEval(
#     X_test.processed_text, X_test.constitution_article,
#     X_test.crpc_article, X_test.ipc_article,
#     y_test, tokenizer
# )

dev_dataset = LegEval(
    X_dev.processed_text, X_dev.constitution_article,
    X_dev.crpc_article, X_dev.ipc_article,
    y_dev, tokenizer
)

In [13]:
BATCH_SIZE = 64 # batch size for training

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_dataloader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False)
# test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)AXC

In [14]:
class CNNClassifier_1(nn.Module):
    def __init__(self,
                 vocab_size,
                 output_size,
                 embedding_size=300,
                 in_channels=1,
                 out_channels=100,
                 kernel_sizes=[3,4,5]):
        super(CNNClassifier_1, self).__init__()

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.output_size = output_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_sizes = kernel_sizes

        self.embed = nn.Embedding(self.vocab_size, self.embedding_size)
        self.convs = nn.ModuleList(
            [nn.Conv2d(self.in_channels, self.out_channels,
                       (kernel_size, self.embedding_size))
             for kernel_size in self.kernel_sizes])
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(len(self.kernel_sizes) * self.out_channels,
                             self.output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.embed(torch.squeeze(x))  # (batch_size, sequence_length, embedding_size)
        x = x.unsqueeze(1)  # (batch_size, in_channels, sequence_length, embedding_size)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]  # [(batch_size, out_channels, embedding_size), ...]*len(kernel_sizes)
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]  # [(batch_size, out_channels), ...]*len(kernel_sizes)
        x = torch.cat(x, 1)  # (batch_size, len(kernel_sizes)*out_channels)
        x = self.dropout(x)  # (batch_size, len(kernel_sizes)*out_channels)
        y = self.fc1(x)  # (batch_size, output_size)
        return y

    def predict(self, x, threshold=0.5):
        preds = self.sigmoid(self.forward(x))
        preds = np.array(preds.cpu() > threshold, dtype=float)
        return preds

In [15]:
class CNNClassifier_2(nn.Module):
    def __init__(self,
                 vocab_size,
                 output_size,
                 embedding_size=300,
                 in_channels=1,
                 out_channels=100,
                 kernel_sizes=[3,4,5]):
        super(CNNClassifier_2, self).__init__()

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.output_size = output_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_sizes = kernel_sizes
        self.output2=50
        self.embed = nn.Embedding(self.vocab_size, self.embedding_size)
        self.input2 = nn.Linear(embedding_size, self.output2)
        self.convs = nn.ModuleList(
            [nn.Conv2d(self.in_channels, self.out_channels,
                       (kernel_size, self.embedding_size))
             for kernel_size in self.kernel_sizes])
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(len(self.kernel_sizes) * self.out_channels+self.output2,
                             self.output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, labels):
        x = self.embed(torch.squeeze(x))  # (batch_size, sequence_length, embedding_size)
        # print(labels)
        l = self.embed(torch.squeeze(labels))
        input2 = self.input2(l)
        x = x.unsqueeze(1)  # (batch_size, in_channels, sequence_length, embedding_size)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]  # [(batch_size, out_channels, embedding_size), ...]*len(kernel_sizes)
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]  # [(batch_size, out_channels), ...]*len(kernel_sizes)
        x = torch.cat(x, 1)  # (batch_size, len(kernel_sizes)*out_channels)
        x = self.dropout(x)  # (batch_size, len(kernel_sizes)*out_channels)
        combined = torch.cat((x.view(x.size(0), -1),
                              input2.view(input2.size(0), -1)), dim=1)
        y = self.fc1(combined)  # (batch_size, output_size)
        return y

    def predict(self, x, labels, threshold=0.5):
        preds = self.sigmoid(self.forward(x, labels))
        preds = np.array(preds.cpu() > threshold, dtype=float)
        return preds

In [16]:
class RnnType:
    GRU = 1
    LSTM = 2

class AttentionModel:
    NONE = 0
    DOT = 1
    GENERAL = 2

class Parameters:
    def __init__(self, data_dict):
        for k, v in data_dict.items():
            exec("self.%s=%s" % (k, v))


class Attention(nn.Module):
    def __init__(self, device, method, hidden_size):
        super(Attention, self).__init__()
        self.device = device

        self.method = method
        self.hidden_size = hidden_size

        self.concat_linear = nn.Linear(self.hidden_size * 2, self.hidden_size)

        if self.method == AttentionModel.GENERAL:
            self.attn = nn.Linear(self.hidden_size, hidden_size)

    def forward(self, rnn_outputs, final_hidden_state):
        # rnn_output.shape:         (batch_size, seq_len, hidden_size)
        # final_hidden_state.shape: (batch_size, hidden_size)
        # NOTE: hidden_size may also reflect bidirectional hidden states (hidden_size = num_directions * hidden_dim)
        batch_size, seq_len, _ = rnn_outputs.shape
        if self.method == AttentionModel.DOT:
            attn_weights = torch.bmm(rnn_outputs, final_hidden_state.unsqueeze(2))
        elif self.method == AttentionModel.GENERAL:
            attn_weights = self.attn(rnn_outputs) # (batch_size, seq_len, hidden_dim)
            attn_weights = torch.bmm(attn_weights, final_hidden_state.unsqueeze(2))

        else:
            raise Exception("[Error] Unknown AttentionModel.")

        attn_weights = torch.softmax(attn_weights.squeeze(2), dim=1)

        context = torch.bmm(rnn_outputs.transpose(1, 2), attn_weights.unsqueeze(2)).squeeze(2)

        attn_hidden = torch.tanh(self.concat_linear(torch.cat((context, final_hidden_state), dim=1)))

        return attn_hidden, attn_weights


class RnnClassifier(nn.Module):
    def __init__(self, device, params):
        super(RnnClassifier, self).__init__()
        self.params = params
        self.device = device

        # Embedding layer
        self.word_embeddings = nn.Embedding(self.params.vocab_size, self.params.embed_dim)

        # Calculate number of directions
        self.num_directions = 2 if self.params.bidirectional == True else 1

        self.linear_dims = [self.params.rnn_hidden_dim * self.num_directions] + self.params.linear_dims
        self.linear_dims.append(self.params.label_size)

        # RNN layer
        rnn = None
        if self.params.rnn_type == RnnType.GRU:
            rnn = nn.GRU
        elif self.params.rnn_type == RnnType.LSTM:
            rnn = nn.LSTM
        else:
            raise Exception("[Error] Unknown RnnType. Currently supported: RnnType.GRU=1, RnnType.LSTM=2")
        self.rnn = rnn(self.params.embed_dim,
                       self.params.rnn_hidden_dim,
                       num_layers=self.params.num_layers,
                       bidirectional=self.params.bidirectional,
                       dropout=self.params.dropout,
                       batch_first=False)


        # Define set of fully connected layers (Linear Layer + Activation Layer) * #layers
        self.linears = nn.ModuleList()
        for i in range(0, len(self.linear_dims)-1):
            if self.params.dropout > 0.0:
                self.linears.append(nn.Dropout(p=self.params.dropout))
            linear_layer = nn.Linear(self.linear_dims[i], self.linear_dims[i+1])
            self.init_weights(linear_layer)
            self.linears.append(linear_layer)
            if i == len(self.linear_dims) - 1:
                break  # no activation after output layer!!!
            self.linears.append(nn.ReLU())

        self.hidden = None

        # Choose attention model
        if self.params.attention_model != AttentionModel.NONE:
            self.attn = Attention(self.device, self.params.attention_model, self.params.rnn_hidden_dim * self.num_directions)
        self.sigmoid = nn.Sigmoid()


    def init_hidden(self, batch_size):
        if self.params.rnn_type == RnnType.GRU:
            return torch.zeros(self.params.num_layers * self.num_directions, batch_size, self.params.rnn_hidden_dim).to(self.device)
        elif self.params.rnn_type == RnnType.LSTM:
            return (torch.zeros(self.params.num_layers * self.num_directions, batch_size, self.params.rnn_hidden_dim).to(self.device),
                    torch.zeros(self.params.num_layers * self.num_directions, batch_size, self.params.rnn_hidden_dim).to(self.device))
        else:
            raise Exception('Unknown rnn_type. Valid options: "gru", "lstm"')

    # def freeze_layer(self, layer):
    #     for param in layer.parameters():
    #         param.requires_grad = False


    def forward(self, inputs):
        batch_size, seq_len, ems = inputs.shape

        # Push through embedding layer
        X = self.word_embeddings(torch.squeeze(inputs)).transpose(0, 1)

        self.hidden = self.init_hidden(batch_size)
        # Push through RNN layer
        rnn_output, self.hidden = self.rnn(X, self.hidden)

        # Extract last hidden state
        final_state = None
        if self.params.rnn_type == RnnType.GRU:
            final_state = self.hidden.view(self.params.num_layers, self.num_directions, batch_size, self.params.rnn_hidden_dim)[-1]
        elif self.params.rnn_type == RnnType.LSTM:
            final_state = self.hidden[0].view(self.params.num_layers, self.num_directions, batch_size, self.params.rnn_hidden_dim)[-1]
        # Handle directions
        final_hidden_state = None
        if self.num_directions == 1:
            final_hidden_state = final_state.squeeze(0)
        elif self.num_directions == 2:
            h_1, h_2 = final_state[0], final_state[1]
            final_hidden_state = torch.cat((h_1, h_2), 1)  # Concatenate both states

        # Push through attention layer
        if self.params.attention_model != AttentionModel.NONE:
            rnn_output = rnn_output.permute(1, 0, 2)  #
            X = self.attn(rnn_output, final_hidden_state)[0]
        else:
            X = final_hidden_state

        # Push through linear layers
        for l in self.linears:
            X = l(X)

        return X


    def init_weights(self, layer):
        if type(layer) == nn.Linear:
            # print("Initialize layer with nn.init.xavier_uniform_: {}".format(layer))
            torch.nn.init.xavier_uniform_(layer.weight)
            layer.bias.data.fill_(0.01)

    def predict(self, x, threshold=0.5):
        preds = self.sigmoid(self.forward(x))
        # print(preds)
        preds = np.array(preds.cpu() > threshold, dtype=float)
        # print(preds)
        return preds

In [17]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss >= (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [18]:
parameters_dictionary = {}
parameters = Parameters({'vocab_size': tokenizer.vocab_size, 'embed_dim': 200,
                         'rnn_hidden_dim': 300, 'bidirectional': True, 'linear_dims': [300, 1],
                         'label_size': 1, 'rnn_type': RnnType.GRU, 'num_layers': 1,
                         'dropout': 0.0, 'attention_model': AttentionModel.GENERAL}
                        )

In [19]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyperparameters
EPOCHS = 30 # epoch
LR = 0.01  # learning rate

# model = CNNClassifier_1(
#     tokenizer.vocab_size,
#     1
# )

model = RnnClassifier(
    torch.device(device),
    parameters
)

# model = CNNClassifier_2(
#     tokenizer.vocab_size,
#     2
# )
model.to(device)

loss_fun = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

early_stopper = EarlyStopper(patience=3, min_delta=0.02)
for epoch in range(1, EPOCHS + 1):
    epoch_loss = 0
    model.train()
    for idx, data in enumerate(train_dataloader):
        # print(data['label'])
        # print(idx)
        optimizer.zero_grad()
        outputs = model(data['input_ids'].to(device))
        # print(outputs)
        loss = loss_fun(outputs, data['label'].to(device))
        loss.backward()
        # print(model.linears[4].weight.grad)
        optimizer.step()
        epoch_loss += loss.item()

    model.eval()
    outputs = []
    targets = []
    with torch.no_grad():
        for idx, data in enumerate(dev_dataloader):

            output_batch = model.predict(data['input_ids'].to(device))
            target_batch = np.array(data['label'])
            outputs.extend(output_batch)
            targets.extend(target_batch)
            # dev_dataloader[idx]['label_1_output'] = outputs

    micro_f1 = f1_score(targets, outputs, average='micro')
    dev_loss = loss_fun(torch.FloatTensor(outputs), torch.FloatTensor(targets))
    if early_stopper.early_stop(dev_loss):
        break
    print(f'\rEpoch: {epoch}/{EPOCHS}, Micro-f1: {micro_f1:.3f}, Train Loss: {epoch_loss/len(train_dataloader):.3f}, Dev Loss: {dev_loss:.3f}', end='')

  dev_loss = loss_fun(torch.FloatTensor(outputs), torch.FloatTensor(targets))


Epoch: 9/30, Micro-f1: 0.555, Train Loss: 0.044, Dev Loss: 0.718

In [20]:
print(classification_report(targets, outputs))

dev_dataset = LegEval(
    X_dev.processed_text, X_dev.constitution_article,
    X_dev.crpc_article, X_dev.ipc_article,
    y_dev, tokenizer
)

dev_dataloader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False)

              precision    recall  f1-score   support

         0.0       0.53      0.58      0.55       497
         1.0       0.54      0.49      0.51       497

    accuracy                           0.53       994
   macro avg       0.53      0.53      0.53       994
weighted avg       0.53      0.53      0.53       994



In [None]:
# train_dataset = LegEval(
#     X_train.processed_text, X_train.constitution_article,
#     X_train.crpc_article, X_train.ipc_article,
#     y_train, tokenizer
# )
#
# test_dataset = LegEval(
#     X_test.processed_text, X_test.constitution_article,
#     X_test.crpc_article, X_test.ipc_article,
#     y_test, tokenizer
# )
#
# dev_dataset = LegEval(
#     X_dev.processed_text, X_dev.constitution_article,
#     X_dev.crpc_article, X_dev.ipc_article,
#     y_dev, tokenizer
# )

In [None]:
model.eval()
train_outputs = []
targets = []
with torch.no_grad():
    for idx, data in enumerate(train_dataloader):
        output_batch = model.predict(data['input_ids'].to(device))
        target_batch = np.array(data['label'])
        train_outputs.extend(output_batch)
        targets.extend(target_batch)

print(classification_report(targets, train_outputs))

train_dataset = LegEval(
    X_train.processed_text, X_train.constitution_article,
    X_train.crpc_article, X_train.ipc_article,
    y_train, tokenizer
)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# model.eval()
# outputs_test = []
# targets = []
# with torch.no_grad():
#     for idx, data in enumerate(test_dataloader):
#         output_batch = model.predict(data['input_ids'].to(device))
#         target_batch = np.array(data['label'])
#         outputs_test.extend(output_batch)
#         targets.extend(target_batch)
#
# print(classification_report(targets, outputs_test))
#
# test_dataset = LegEval(
#     X_test.processed_text, X_test.constitution_article,
#     X_test.crpc_article, X_test.ipc_article,
#     y_test, tokenizer
# )
#
# test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyperparameters
EPOCHS = 30 # epoch
LR = 0.001  # learning rate

model = CNNClassifier_2(
    tokenizer.vocab_size,
    2
)
model.to(device)

loss_fun = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

early_stopper = EarlyStopper(patience=3, min_delta=0.02)
for epoch in range(1, EPOCHS + 1):
    epoch_loss = 0
    model.train()
    for idx, data in enumerate(train_dataloader):
        # print(data['label'])
        # print(idx)
        optimizer.zero_grad()
        outputs = model(data['input_ids'].to(device), data['input_constitution'])
        # print(outputs)
        loss = loss_fun(outputs, data['label'].to(device))
        loss.backward()
        # print(model.linears[4].weight.grad)
        optimizer.step()
        epoch_loss += loss.item()

    model.eval()
    outputs = []
    targets = []
    with torch.no_grad():
        for idx, data in enumerate(dev_dataloader):

            output_batch = model.predict(data['input_ids'].to(device), data['input_constitution'])
            target_batch = np.array(data['label'])
            outputs.extend(output_batch)
            targets.extend(target_batch)
            # dev_dataloader[idx]['label_1_output'] = outputs

    micro_f1 = f1_score(targets, outputs, average='micro')
    dev_loss = loss_fun(torch.FloatTensor(outputs), torch.FloatTensor(targets))
    if early_stopper.early_stop(dev_loss):
        break
    print(f'\rEpoch: {epoch}/{EPOCHS}, Micro-f1: {micro_f1:.3f}, Train Loss: {epoch_loss/len(train_dataloader):.3f}, Dev Loss: {dev_loss:.3f}', end='')

In [None]:
# torch.has_mps