#Adaptation of the TEDn implementation for Positive-Unlabeled Learning, based on the original implementation:

https://github.com/acmi-lab/pu_learning



In [None]:
# Notebook shell commands: install the two libraries imported by the cells below.
!pip install torch
!pip install transformers

In [None]:
import sys
from google.colab import drive
# Mount Google Drive at /content/gdrive; later cells write checkpoints and
# the inference CSV to paths under this mount point.
drive.mount('/content/gdrive')

In [None]:
#Define model storage directory and data location:
# Both values are placeholders and must be replaced before running.
# experiment_dir is joined with a checkpoint file name in the torch.load calls below.
experiment_dir = 'MODEL STORAGE DIR'
# data_dir is passed to pd.read_csv, so it is the path of a CSV file.
data_dir = 'DATA LOCATION'

# Model helpers

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig


class AutoClassifier(AutoModelForSequenceClassification):
    """Sequence classifier wrapper that takes one stacked token array.

    The call operator expects ``x`` to be indexable as ``x[:, :, k]``,
    with channel 0 holding the input ids and channel 1 the attention mask.

    NOTE(review): the training and evaluation functions in this file call
    the loaded network as ``net(inputs[:, :, 0]).to_tuple()[0]``, which
    suggests the object returned by ``from_pretrained`` does not go
    through this ``__call__`` -- confirm before relying on it.
    """

    def __init__(self, config):
        super().__init__(config)

    def __call__(self, x):
        """Unpack ids and mask from ``x`` and return item 0 of the parent call."""
        return super().__call__(
            input_ids=x[:, :, 0],
            attention_mask=x[:, :, 1],
        )[0]

def init_classification_model(model_name):
    """Load a pretrained sequence classifier configured for two labels.

    Args:
        model_name: Identifier passed to ``AutoConfig.from_pretrained`` and
            ``AutoClassifier.from_pretrained``.

    Returns:
        The model object returned by ``AutoClassifier.from_pretrained``.
    """
    cfg = AutoConfig.from_pretrained(model_name)
    # Two output classes; 510 matches the max_length used by the
    # tokenizer calls elsewhere in this file.
    cfg.num_labels, cfg.max_len = 2, 510
    return AutoClassifier.from_pretrained(model_name, config=cfg)

def get_model(model_type, input_dim=None):
    """Build the network for ``model_type``.

    Args:
        model_type: Name of the architecture; only ``"Roberta"`` is
            implemented.
        input_dim: Unused; kept so existing call sites keep working.

    Returns:
        The classifier built by ``init_classification_model``.

    Raises:
        ValueError: If ``model_type`` is not ``"Roberta"``. Previously a
            message was printed and ``None`` was returned, which only
            failed later when the caller used the result as a model.
    """
    if model_type == "Roberta":
        return init_classification_model('hfl/chinese-roberta-wwm-ext')

    raise ValueError(
        f"Unsupported model_type {model_type!r}: "
        "must implement model if model other than Roberta"
    )

# Datahelper

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from transformers import BertTokenizerFast, DistilBertTokenizerFast, AutoTokenizer
from sklearn.model_selection import train_test_split

def read_tifa_split(df, n_pos=19000, n_unl=97000):
    """Read the labelled CSV, subsample it, and make a stratified 80/20 split.

    Args:
        df: Path (or buffer) handed to ``pd.read_csv``. The file must have
            a ``hu_tifa`` column (1 = positive, 0 = unlabeled) and a
            ``para_split`` text column.
        n_pos: Number of rows with ``hu_tifa == 1`` to sample. The default
            is the value that used to be hard-coded.
        n_unl: Number of rows with ``hu_tifa == 0`` to sample. The default
            is the value that used to be hard-coded.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as returned by
        ``train_test_split`` with ``test_size=0.2``, ``random_state=42``
        and stratification on the label.

    Raises:
        ValueError: Propagated from ``DataFrame.sample`` when a class has
            fewer rows than requested.
    """
    tifa = pd.read_csv(df)

    # Sample positives first, then unlabeled rows, so the global random
    # state is consumed in the same order as before.
    tifa_ones = tifa[tifa.hu_tifa == 1].sample(n_pos)
    tifa_zeros = tifa[tifa.hu_tifa == 0].sample(n_unl)

    # DataFrame.append was removed in pandas 2.0; pd.concat is the
    # equivalent that works on both old and new versions.
    tifa = pd.concat([tifa_ones, tifa_zeros]).reset_index(drop=True)

    X, y = tifa['para_split'].tolist(), tifa['hu_tifa'].tolist()

    return train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

def getBertTokenizer(model):
    """Return the tokenizer for ``model``.

    Only ``'hfl/chinese-roberta-wwm-ext'`` is recognised; any other value
    raises ``ValueError``.
    """
    if model != 'hfl/chinese-roberta-wwm-ext':
        raise ValueError(f'Model: {model} not recognized.')

    return AutoTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext', add_prefix_space=True)

def initialize_bert_transform(net):
    """Build a text -> stacked-token-array transform for ``net``.

    Args:
        net: Model identifier understood by ``getBertTokenizer``.

    Returns:
        A function that tokenizes its argument (padding and truncating to
        510 tokens) and stacks ``input_ids``, ``attention_mask`` and
        ``token_type_ids`` along axis 2, in that order.
    """
    tokenizer = getBertTokenizer(net)

    def transform(text):
        encoded = tokenizer(text, padding='max_length', truncation=True, max_length=510)

        if net == 'hfl/chinese-roberta-wwm-ext':
            channels = tuple(
                encoded[field]
                for field in ('input_ids', 'attention_mask', 'token_type_ids')
            )
            x = np.stack(channels, axis=2)
        return x

    return transform

class TiFaData(torch.utils.data.Dataset):
    """Container that encodes texts once and splits them by label.

    After construction ``p_data`` holds the encodings whose label is 1 and
    ``u_data`` those whose label is 0. No ``__getitem__`` is defined here;
    the object is consumed through its attributes.
    """

    def __init__(self, data, labels, transform):
        label_arr = np.array(labels)
        encoded = transform(data)

        # Row selectors for positive (1) and unlabeled (0) examples.
        self.p_data = encoded[np.where(label_arr == 1)[0], :, :]
        self.u_data = encoded[np.where(label_arr == 0)[0], :, :]

        self.labels = label_arr

        # Read by get_PUDataSplits and forwarded to the split datasets.
        self.transform = None
        self.target_transform = None

    def __len__(self):
        return len(self.labels)

# Get data

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils import data

import torchvision
import torchvision.transforms as transforms

import numpy as np

# New get_dataset func:
class PosData(torch.utils.data.Dataset):
    """Dataset of labelled-positive examples; every target is 0.

    Args:
        transform: Stored but not applied by this class.
        target_transform: Optional callable applied to the target in
            ``__getitem__``.
        data: Array of encoded examples; ``data.shape[0]`` is the dataset
            length.
        index: Per-example identifiers returned alongside each item.
            Defaults to ``0 .. len(data) - 1``.
        data_type: Optional tag stored on the instance. When omitted, the
            module-level ``data_type`` is used if it exists (this is what
            the original code read implicitly), otherwise ``None``.
    """

    def __init__(self, transform=None, target_transform=None, data=None, \
            index=None, data_type=None):

        self.transform = transform
        self.target_transform = target_transform

        self.data = data

        # Positives carry the constant target 0.
        self.targets = np.zeros(data.shape[0], dtype=np.int_)

        # The original referenced a bare global here, which raises
        # NameError unless a later notebook cell has already defined it.
        if data_type is None:
            data_type = globals().get("data_type")
        self.data_type = data_type

        self.index = np.arange(data.shape[0]) if index is None else index

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(index, encoded_example, target)`` for position ``idx``."""
        index, txt, target = self.index[idx], self.data[idx], self.targets[idx]

        if self.target_transform is not None:
            target = self.target_transform(target)

        return index, txt, target
    

class UnlabelData(torch.utils.data.Dataset):
    """Dataset of unlabeled examples; every target is 1.

    Args:
        transform: Stored but not applied by this class.
        target_transform: Optional callable applied to the target in
            ``__getitem__``.
        u_data: Array of encoded examples; ``u_data.shape[0]`` is the
            dataset length.
        index: Per-example identifiers returned alongside each item.
            Defaults to ``0 .. len(u_data) - 1``.
        data_type: Optional tag stored on the instance. When omitted, the
            module-level ``data_type`` is used if it exists (this is what
            the original code read implicitly), otherwise ``None``.
    """

    def __init__(self, transform=None, target_transform=None, \
            u_data=None, index=None, data_type=None):

        self.transform = transform
        self.target_transform = target_transform

        self.data = u_data

        # Label convention in this file: positives = 0, unlabeled = 1.
        self.targets = np.ones(u_data.shape[0], dtype=np.int_)

        # The original referenced a bare global here, which raises
        # NameError unless a later notebook cell has already defined it.
        if data_type is None:
            data_type = globals().get("data_type")
        self.data_type = data_type

        self.index = np.arange(u_data.shape[0]) if index is None else index

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(index, encoded_example, target)`` for position ``idx``."""
        index, img, target = self.index[idx], self.data[idx], self.targets[idx]

        if self.target_transform is not None:
            target = self.target_transform(target)

        return index, img, target

def get_PUDataSplits(data_obj):
    """Wrap the positive and unlabeled arrays of ``data_obj`` as datasets.

    Args:
        data_obj: Object exposing ``p_data``, ``u_data``, ``transform`` and
            ``target_transform`` (for example a ``TiFaData`` instance).

    Returns:
        ``(PosData, UnlabelData, pos_size, unlabel_size)`` where each
        dataset is indexed ``0 .. size - 1``.
    """
    positives, unlabeled = data_obj.p_data, data_obj.u_data
    pos_size = positives.shape[0]
    unlabel_size = unlabeled.shape[0]

    pos_set = PosData(
        transform=data_obj.transform,
        target_transform=data_obj.target_transform,
        data=positives,
        index=np.arange(pos_size),
    )
    unl_set = UnlabelData(
        transform=data_obj.transform,
        target_transform=data_obj.target_transform,
        u_data=unlabeled,
        index=np.arange(unlabel_size),
    )
    return pos_set, unl_set, pos_size, unlabel_size
            
def get_dataset(data_dir, data_type, net_type, device, batch_size):
    """Build the PU train/validation loaders and the network.

    Args:
        data_dir: Path of the CSV read by ``read_tifa_split``.
        data_type: Dataset selector; only ``"TiFa_BERT"`` is implemented.
        net_type: Model selector forwarded to ``get_model``.
        device: Device the network is moved to.
        batch_size: Batch size for all four loaders. The original ignored
            this argument and always used 16.

    Returns:
        ``(p_trainloader, u_trainloader, p_validloader, u_validloader,
        net, pos_size, unl_size)`` where the two sizes count the positive
        and unlabeled training examples.

    Raises:
        ValueError: If ``data_type`` is not supported. The original fell
            through to the return statement and failed with
            ``UnboundLocalError``.
    """
    if data_type != "TiFa_BERT":
        raise ValueError(f"Unsupported data_type {data_type!r}: only 'TiFa_BERT' is implemented.")

    train_texts, test_texts, train_labels, test_labels = read_tifa_split(data_dir)

    # Tokenizer-backed transform producing the stacked token arrays.
    transform = initialize_bert_transform('hfl/chinese-roberta-wwm-ext')

    # Tokenize once and separate positive from unlabeled examples.
    train_dataset = TiFaData(train_texts, train_labels, transform=transform)
    test_dataset = TiFaData(test_texts, test_labels, transform=transform)

    p_traindata, u_traindata, pos_size, unl_size = get_PUDataSplits(train_dataset)
    p_validdata, u_validdata, _, _ = get_PUDataSplits(test_dataset)

    def _loader(dataset):
        # Every loader shares the same settings, as in the original code.
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    p_trainloader = _loader(p_traindata)
    u_trainloader = _loader(u_traindata)
    p_validloader = _loader(p_validdata)
    u_validloader = _loader(u_validdata)

    net = get_model(net_type)
    net = net.to(device)

    return p_trainloader, u_trainloader, p_validloader, u_validloader, net, pos_size, unl_size

# Now the Algorithm

In [None]:
import torch
import numpy as np
import os


def sigmoid_loss(out, y):
    """Mean of the score each row assigns to the class opposite its target.

    For a row with target ``y`` (0 or 1) the value taken from ``out`` is
    column ``1 - y``; the result is the mean over rows.
    """
    opposite_class = (1 - y).unsqueeze(1)
    return out.gather(1, opposite_class).mean()


def validate(epoch, net, u_validloader, criterion, device, threshold, logistic=False, show_bar=True, separate=False):
    """Evaluate ``net`` on a loader of ``(index, inputs, targets)`` batches.

    A sample is predicted as class 1 when its softmax probability for
    class 0 is ``<= threshold``; accuracy is the share of predictions that
    equal the targets.

    Args:
        epoch: Epoch number, only used in the progress message.
        net: Model called with ``input_ids`` / ``attention_mask`` keywords
            whose output exposes ``to_tuple()``.
        u_validloader: Iterable of batches; ``inputs[:, :, 0]`` are token
            ids and ``inputs[:, :, 1]`` the attention mask.
        criterion: Loss callable. When ``None`` and ``logistic`` is false,
            ``sigmoid_loss`` is used; when ``None`` and ``logistic`` is
            true, the loss is skipped.
        device: Device inputs and targets are moved to.
        threshold: Decision threshold on the class-0 probability.
        logistic: When false, the loss is computed on softmax outputs
            instead of raw logits.
        show_bar: Print the epoch banner when true.
        separate: Also return per-class accuracies.

    Returns:
        Accuracy in percent, or ``(overall, class0, class1)`` when
        ``separate`` is true. A percentage whose denominator is zero is
        reported as ``nan``.
    """
    def _percent(hits, count):
        # Empty denominators are reported as nan instead of raising.
        return 100. * hits / count if count else float('nan')

    if show_bar:
        print('\nTest Epoch: %d' % epoch)

    net.eval()
    test_loss = 0
    correct = 0
    total = 0

    pos_correct = 0
    neg_correct = 0

    pos_total = 0
    neg_total = 0

    if (not logistic) and (criterion is None):
        criterion = sigmoid_loss

    with torch.no_grad():
        for _, inputs, targets in u_validloader:

            inputs = inputs.to(device)
            # Targets must live on the same device as the outputs for the
            # loss and the comparison below.
            targets = targets.to(device)

            # Inputs are padded to a fixed length, so the mask is needed
            # to keep padding tokens out of attention.
            outputs = net(
                input_ids=inputs[:, :, 0],
                attention_mask=inputs[:, :, 1],
            ).to_tuple()[0]

            class0_prob = torch.nn.functional.softmax(outputs, dim=-1)[:, 0]
            predicted = class0_prob <= threshold

            if not logistic:
                outputs = torch.nn.functional.softmax(outputs, dim=-1)

            if criterion is not None:
                # Accumulated for inspection only; not part of the return value.
                test_loss += criterion(outputs, targets).item()

            total += outputs.size(0)

            correct_preds = predicted.eq(targets).cpu().numpy()
            correct += np.sum(correct_preds)

            if separate:
                # reshape(-1) keeps a 1-D array even for a batch of one.
                true_numpy = targets.cpu().numpy().reshape(-1)
                pos_idx = np.where(true_numpy == 0)[0]
                neg_idx = np.where(true_numpy == 1)[0]

                pos_correct += np.sum(correct_preds[pos_idx])
                neg_correct += np.sum(correct_preds[neg_idx])

                pos_total += len(pos_idx)
                neg_total += len(neg_idx)

    if not separate:
        return _percent(correct, total)
    return _percent(correct, total), _percent(pos_correct, pos_total), _percent(neg_correct, neg_total)



def train(epoch, net, p_trainloader, u_trainloader, optimizer, criterion, device, show_bar=True):
    """Run one epoch of plain positive-vs-unlabeled training.

    Batches from the two loaders are paired with ``zip`` (so the epoch
    ends when the shorter loader is exhausted), concatenated, and passed
    through the network once. The loss is the mean of the positive-batch
    loss and the unlabeled-batch loss.

    Args:
        epoch: Unused; kept for interface compatibility.
        net: Model called with ``input_ids`` / ``attention_mask`` keywords
            whose output exposes ``to_tuple()``.
        p_trainloader: Loader of ``(index, inputs, targets)`` positives.
        u_trainloader: Loader of ``(index, inputs, targets)`` unlabeled.
        optimizer: Optimizer stepping ``net``'s parameters.
        criterion: Loss callable taking ``(logits, targets)``.
        device: Device tensors are moved to.
        show_bar: Unused; kept for interface compatibility.

    Returns:
        Training accuracy in percent (argmax prediction against the
        targets), or ``nan`` when no batch was processed.
    """
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    for p_batch, u_batch in zip(p_trainloader, u_trainloader):
        optimizer.zero_grad()

        _, p_inputs, p_targets = p_batch
        _, u_inputs, u_targets = u_batch

        p_targets = p_targets.to(device)
        u_targets = u_targets.to(device)

        # One forward pass over positives followed by unlabeled examples.
        inputs = torch.cat((p_inputs, u_inputs), dim=0).to(device)
        targets = torch.cat((p_targets, u_targets), dim=0)

        # Channel 0 holds token ids and channel 1 the attention mask; the
        # mask keeps the fixed-length padding out of attention, matching
        # what the inference function in this file does.
        outputs = net(
            input_ids=inputs[:, :, 0],
            attention_mask=inputs[:, :, 1],
        ).to_tuple()[0]

        n_pos = len(p_targets)
        p_loss = criterion(outputs[:n_pos], p_targets)
        u_loss = criterion(outputs[n_pos:], u_targets)

        loss = (p_loss + u_loss) / 2.0
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += np.sum(predicted.eq(targets).cpu().numpy())

    if not total:
        # No batches at all: accuracy is undefined rather than an error.
        return float('nan')
    return 100. * correct / total


def train_PU_discard(epoch, net,  p_trainloader, u_trainloader, optimizer, criterion, device, keep_sample=None, show_bar=True,
                     save_dir='/content/gdrive/MyDrive/Capstone/classification/overall_tifa/hu_models_2'):
    """Run one training epoch that drops the unlabeled samples marked as discarded.

    Batches from the two loaders are paired with ``zip``. Within each
    unlabeled batch only rows whose entry in ``keep_sample`` equals 1 are
    used; a pair is skipped entirely when no unlabeled row survives. After
    the epoch the model and optimizer state are saved.

    Args:
        epoch: Epoch number, used in the checkpoint file name.
        net: Model called with ``input_ids`` / ``attention_mask`` keywords
            whose output exposes ``to_tuple()``.
        p_trainloader: Loader of ``(index, inputs, targets)`` positives.
        u_trainloader: Loader of ``(index, inputs, targets)`` unlabeled;
            ``index`` addresses ``keep_sample``.
        optimizer: Optimizer stepping ``net``'s parameters.
        criterion: Loss callable taking ``(logits, targets)``.
        device: Device tensors are moved to.
        keep_sample: Array with 1 for unlabeled samples to keep (as
            produced by ``rank_inputs``). ``None`` keeps every sample;
            the original crashed on ``None``.
        show_bar: Unused; kept for interface compatibility.
        save_dir: Directory for ``best_model_epoch_{epoch}.pt``. Defaults
            to the path that used to be hard-coded; ``None`` disables
            saving.

    Returns:
        ``(accuracy_percent, summed_positive_loss, summed_unlabeled_loss)``.
        The accuracy is ``nan`` when every batch was skipped.
    """
    net.train()
    train_loss = 0
    total_p_loss = 0
    total_u_loss = 0
    correct = 0
    total = 0

    for p_batch, u_batch in zip(p_trainloader, u_trainloader):

        _, p_inputs, p_targets = p_batch
        u_index, u_inputs, u_targets = u_batch

        # Positions inside this batch whose sample is still kept.
        if keep_sample is None:
            u_idx = np.arange(len(u_index))
        else:
            u_idx = np.where(keep_sample[u_index.numpy()] == 1)[0]

        if len(u_idx) < 1:
            continue

        optimizer.zero_grad()

        u_targets = u_targets[u_idx]
        u_inputs = u_inputs[u_idx]

        p_targets = p_targets.to(device)
        u_targets = u_targets.to(device)

        inputs = torch.cat((p_inputs, u_inputs), dim=0).to(device)
        targets = torch.cat((p_targets, u_targets), dim=0)

        # Channel 0 holds token ids and channel 1 the attention mask; the
        # mask keeps the fixed-length padding out of attention, matching
        # what the inference function in this file does.
        outputs = net(
            input_ids=inputs[:, :, 0],
            attention_mask=inputs[:, :, 1],
        ).to_tuple()[0]

        n_pos = len(p_targets)
        p_loss = criterion(outputs[:n_pos], p_targets)
        u_loss = criterion(outputs[n_pos:], u_targets)

        loss = (p_loss + u_loss) / 2.0

        loss.backward()
        optimizer.step()

        total_p_loss += p_loss.item()
        total_u_loss += u_loss.item()
        train_loss += loss.item()

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += np.sum(predicted.eq(targets).cpu().numpy())

    if save_dir is not None:
        # One checkpoint per epoch, holding model and optimizer state.
        os.makedirs(save_dir, exist_ok=True)
        model_name = f'best_model_epoch_{epoch}.pt'
        root_model_path = os.path.join(save_dir, model_name)
        state_dict = {'model': net.state_dict(), 'optimizer': optimizer.state_dict()}
        torch.save(state_dict, root_model_path)

    accuracy = 100. * correct / total if total else float('nan')
    return accuracy, total_p_loss, total_u_loss

def rank_inputs(_, net, u_trainloader, device, alpha, u_size):
    """Mark which unlabeled samples to keep for the next training epoch.

    Every unlabeled sample is scored with its softmax probability for
    class 0 (the class used for positives in this file). The
    ``int(alpha * u_size)`` highest-scoring samples are marked 0
    (discard); all others stay 1 (keep).

    Args:
        _: Unused (callers pass the epoch number).
        net: Model called with ``input_ids`` / ``attention_mask`` keywords
            whose output exposes ``to_tuple()``.
        u_trainloader: Loader of ``(index, inputs, targets)`` batches;
            ``index`` gives each sample's position in the returned array.
        device: Device inputs are moved to.
        alpha: Fraction of unlabeled samples to discard. Values outside
            ``[0, 1]`` are clamped; previously ``alpha > 1`` produced a
            negative slice start and discarded fewer samples than asked.
        u_size: Total number of unlabeled samples.

    Returns:
        Float array of length ``u_size`` with 1 for kept samples and 0
        for discarded ones.
    """
    net.eval()
    output_probs = np.zeros(u_size)
    keep_samples = np.ones_like(output_probs)

    with torch.no_grad():
        for idx, inputs, _targets in u_trainloader:
            idx = idx.numpy()
            inputs = inputs.to(device)

            # Channel 0 holds token ids and channel 1 the attention mask.
            outputs = net(
                input_ids=inputs[:, :, 0],
                attention_mask=inputs[:, :, 1],
            ).to_tuple()[0]

            # Column 0 is the positive class (PosData targets are all 0).
            probs = torch.nn.functional.softmax(outputs, dim=-1)[:, 0]
            output_probs[idx] = probs.detach().cpu().numpy().reshape(-1)

    n_discard = min(u_size, max(0, int(alpha * u_size)))
    if n_discard:
        # argsort is ascending, so the tail holds the highest scores.
        sorted_idx = np.argsort(output_probs)
        keep_samples[sorted_idx[u_size - n_discard:]] = 0

    return keep_samples

#Estimator

In [None]:
import random
import numpy as np
import sys 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

def p_probs(net, device, p_loader):
    """Collect the class-0 softmax probability of every sample in ``p_loader``.

    Args:
        net: Model called with ``input_ids`` / ``attention_mask`` keywords
            whose output exposes ``to_tuple()``.
        device: Device inputs are moved to.
        p_loader: Loader of ``(index, inputs, targets)`` batches.

    Returns:
        1-D NumPy array with one probability per sample, in loader order,
        or ``None`` when the loader yields no batches.
    """
    net.eval()
    chunks = []
    with torch.no_grad():
        for _, inputs, _targets in p_loader:
            inputs = inputs.to(device)

            # Channel 0 holds token ids and channel 1 the attention mask.
            outputs = net(
                input_ids=inputs[:, :, 0],
                attention_mask=inputs[:, :, 1],
            ).to_tuple()[0]

            probs = torch.nn.functional.softmax(outputs, dim=-1)[:, 0]

            # reshape(-1) instead of squeeze(): a final batch of size one
            # would otherwise become a 0-d array that cannot be concatenated.
            chunks.append(probs.detach().cpu().numpy().reshape(-1))

    if not chunks:
        return None
    return np.concatenate(chunks, axis=0)

def u_probs(net, device, u_loader):
    """Collect softmax probabilities and targets for every sample in ``u_loader``.

    Args:
        net: Model called with ``input_ids`` / ``attention_mask`` keywords
            whose output exposes ``to_tuple()``.
        device: Device inputs are moved to.
        u_loader: Loader of ``(index, inputs, targets)`` batches.

    Returns:
        ``(probs, targets)`` where ``probs`` has one row of class
        probabilities per sample and ``targets`` is 1-D, both in loader
        order. ``(None, None)`` when the loader yields no batches.
    """
    net.eval()
    prob_chunks = []
    target_chunks = []
    with torch.no_grad():
        for _, inputs, targets in u_loader:
            inputs = inputs.to(device)

            # Channel 0 holds token ids and channel 1 the attention mask.
            outputs = net(
                input_ids=inputs[:, :, 0],
                attention_mask=inputs[:, :, 1],
            ).to_tuple()[0]

            probs = torch.nn.functional.softmax(outputs, dim=-1)

            # No squeeze(): a final batch of size one must stay 2-D (and
            # its targets 1-D) to be concatenated with the other batches.
            prob_chunks.append(probs.detach().cpu().numpy())
            target_chunks.append(targets.cpu().numpy().reshape(-1))

    if not prob_chunks:
        return None, None
    return np.concatenate(prob_chunks), np.concatenate(target_chunks)

def DKW_bound(x, y, t, m, n, delta=0.1, gamma=0.01):
    """Return ``t`` with a symmetric interval around it.

    The half-width is
    ``(sqrt(log(1/delta)/2/n) + sqrt(log(1/delta)/2/m)) * (1 + gamma) / (y/n)``.

    Args:
        x: Unused.
        y: Count whose ratio ``y / n`` scales the half-width.
        t: Point estimate, returned unchanged.
        m: First sample size.
        n: Second sample size.
        delta: Value inside the logarithm.
        gamma: Relative slack added to the half-width.

    Returns:
        ``(t, t - half_width, t + half_width)``.
    """
    log_term = np.log(1 / delta)
    deviation = np.sqrt(log_term / 2 / n) + np.sqrt(log_term / 2 / m)
    half_width = deviation * (1 + gamma) / (y / n)

    return t, t - half_width, t + half_width


def BBE_estimator(pdata_probs, udata_probs, udata_targets):
    """Estimate the positive proportion from positive and unlabeled scores.

    Both score sets are sorted in descending order. Each position ``i`` in
    the unlabeled ranking whose score is strictly greater than the next
    one acts as a cutoff; with ``j`` positive scores at or above that
    cutoff, the candidate estimate is ``(i / n_unlabeled) / (j / n_positive)``.
    Candidates need ``i > 1`` and ``j > 1``. The candidate with the
    smallest upper bound from ``DKW_bound`` is returned.

    Args:
        pdata_probs: 1-D array of class-0 probabilities for positives.
        udata_probs: 2-D array of class probabilities for unlabeled
            samples; only column 0 is used.
        udata_targets: Array aligned with ``udata_probs``; reordered but
            not otherwise used.

    Returns:
        ``(estimate, lower_bounds, upper_bounds)`` with the bounds as
        lists over all candidates, or ``(0.0, 0.0, 0.0)`` when no
        candidate exists.
    """
    # Descending rankings of both score sets.
    p_desc = pdata_probs[np.argsort(pdata_probs)][::-1]

    u_order = np.argsort(udata_probs[:, 0])
    u_desc = udata_probs[:, 0][u_order][::-1]
    # Kept so that a misaligned targets array still fails here, as before.
    _u_targets_desc = udata_targets[u_order][::-1]

    n_u = len(u_desc)
    n_p = len(p_desc)

    estimates, upper_bounds, lower_bounds = [], [], []

    # p_ptr only moves forward because the cutoffs are non-increasing.
    p_ptr = 0

    # The last unlabeled position never qualifies as a cutoff, so the
    # scan stops one short of the end.
    for rank in range(n_u - 1):
        cutoff = u_desc[rank]

        # Only the last element of a run of equal scores is a cutoff.
        if not cutoff > u_desc[rank + 1]:
            continue

        # Count positives scoring at or above the cutoff.
        while p_ptr < n_p and p_desc[p_ptr] >= cutoff:
            p_ptr += 1

        if p_ptr > 1 and rank > 1:
            t = rank * 1.0 * n_p / p_ptr / n_u
            estimate, lower, upper = DKW_bound(rank, p_ptr, t, n_u, n_p)
            estimates.append(estimate)
            upper_bounds.append(upper)
            lower_bounds.append(lower)

    if not upper_bounds:
        return 0.0, 0.0, 0.0

    # Pick the candidate whose upper bound is lowest.
    mpe_estimate = estimates[np.argmin(upper_bounds)]
    print('mpe_estimate', mpe_estimate)
    return mpe_estimate, lower_bounds, upper_bounds

# Plot

In [None]:
import matplotlib.pyplot as plt

def plot_acc_alph(alpha_est, accuracy):
    """Plot alpha estimates and accuracy per epoch on one figure.

    The x-axis runs from 1 to ``len(alpha_est)``. The figure is written
    to ``Acc_alpha_plot.png`` and then shown.
    """
    epochs_axis = np.arange(1, len(alpha_est) + 1, 1)

    plt.figure()
    for series, label in ((alpha_est, "Alpha Est"), (accuracy, "Accuracy")):
        plt.plot(epochs_axis, series, label=label)
    plt.xlabel("Epochs")
    plt.legend(loc='best')
    plt.title("Accuracy and Alpha Estimates")
    plt.savefig("Acc_alpha_plot.png")
    plt.show()


def _plot_epoch_series(series, title, filename):
    """Draw each ``(values, label)`` pair against its own 1-based epoch axis, save and show."""
    plt.figure()
    for values, label in series:
        plt.plot(np.arange(1, len(values) + 1, 1), values, label=label)
    plt.xlabel("Epochs")
    plt.legend(loc='best')
    plt.title(title)
    plt.savefig(filename)
    plt.show()


def plot_losses(p_losses, u_losses, alpha_est, accuracy):
    """Produce the three per-epoch training figures.

    Args:
        p_losses: Positive-loss value per epoch.
        u_losses: Unlabeled-loss value per epoch.
        alpha_est: Alpha estimate per epoch (may be shorter than the
            other series when estimation is disabled).
        accuracy: Training accuracy per epoch.

    Side effects:
        Writes ``loss_plot.png``, ``Acc_alpha_plot.png`` (alpha estimates)
        and ``Accuracy_plot.png``, and shows each figure. Previously the
        accuracy figure was also saved as ``Acc_alpha_plot.png`` and
        overwrote the alpha figure, and its x-axis was sized from
        ``alpha_est`` rather than ``accuracy``.
    """
    _plot_epoch_series(
        [(p_losses, "Positive Loss"), (u_losses, "Unlabeled Loss")],
        "Train until P and U loss Converges",
        "loss_plot.png",
    )
    _plot_epoch_series(
        [(alpha_est, "Alpha Est")],
        "Alpha Estimates",
        "Acc_alpha_plot.png",
    )
    _plot_epoch_series(
        [(accuracy, "Accuracy")],
        "Accuracy Estimates",
        "Accuracy_plot.png",
    )

# Initialize model and data

In [None]:
import os
import argparse
import time
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from transformers import AdamW

# Seed every random number generator used in this file.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Optimizer hyper-parameters (wd and momentum are only used by SGD/Adam).
lr = 2e-7
wd = 5e-4
momentum = 0.9

net_type = "Roberta"
data_type = "TiFa_BERT"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device count', torch.cuda.device_count())
train_method = "TEDn"

## Train set for positive and unlabeled
alpha = 0.5
beta = 0.15
warm_start = True
warm_start_epochs = 5
batch_size = 8
epochs= 5000
log_dir = './TiFa_BERT'
optimizer_str='AdamW'
alpha_estimate=0.1
show_bar = False
use_alpha = False
estimate_alpha = True
load_model = False
# Joined with experiment_dir as the checkpoint file name when load_model is set.
model_name = 'Roberta'
if train_method == "TEDn":
    # TEDn discards by the running alpha estimate instead of the fixed alpha.
    use_alpha=True

#################

## Obtain dataset
p_trainloader, u_trainloader, p_validloader, u_validloader, net, train_pos_size, train_unl_size = \
    get_dataset(data_dir, data_type, net_type, device, batch_size)

best_checkpoint = None
if load_model:
    # map_location lets a checkpoint saved on GPU be restored on CPU.
    best_checkpoint = torch.load(os.path.join(experiment_dir, model_name), map_location=device)
    net.load_state_dict(best_checkpoint['model'])

net = net.to(device)
# Defined on every device; previously a CPU-only run left `criterion` undefined.
criterion = nn.CrossEntropyLoss().to(device)
if not torch.cuda.is_available():
    print('failed to place net and criterion on cuda')

if optimizer_str=="SGD":
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=wd)
elif optimizer_str=="Adam":
    optimizer = optim.Adam(net.parameters(), lr=lr,weight_decay=wd)
elif optimizer_str=="AdamW":
    optimizer = AdamW(net.parameters(), lr=lr)
else:
    raise ValueError(f"Unsupported optimizer_str {optimizer_str!r}")

if best_checkpoint is not None:
    # Restore optimizer state for whichever optimizer was selected; before,
    # only the AdamW path created an optimizer when resuming.
    optimizer.load_state_dict(best_checkpoint['optimizer'])

# Warm-Up Epochs to establish baseline estimate of positive labeled proportion

In [None]:
# Fallback alpha used by the next cell if no estimate is produced here.
alpha_estimate=0.1
## Train in the beginning for warm start

# Warm-up phase: plain PU training (no discarding) for warm_start_epochs
# epochs, re-estimating alpha on the validation loaders after each epoch.
if warm_start and train_method=="TEDn": 
  print('employing warm start')

  # Per-epoch history of alpha estimates and training accuracy.
  alph_est = []
  train_acc_list = []

  for epoch in range(warm_start_epochs): 
      
      train_acc = train(epoch, net, p_trainloader, u_trainloader, \
              optimizer=optimizer, criterion=criterion, device=device, show_bar=show_bar)

      train_acc_list.append(train_acc)

      if estimate_alpha: 
          # Class-0 probabilities for validation positives and full
          # probability rows for validation unlabeled samples.
          pos_probs = p_probs(net, device, p_validloader)
          unlabeled_probs, unlabeled_targets = u_probs(net, device, u_validloader)


          our_mpe_estimate, _, _ = BBE_estimator(pos_probs, unlabeled_probs, unlabeled_targets)

          # The latest estimate is what the following cell starts from.
          alpha_estimate =our_mpe_estimate
          alph_est.append(alpha_estimate)
          print("current alpha estimate:", alpha_estimate)

      # Re-drawn after every epoch.
      # NOTE(review): plot_acc_alph sizes its x-axis from alph_est, so this
      # presumably fails when estimate_alpha is False -- confirm.
      plot_acc_alph(alph_est, train_acc_list)

# Now, run epochs employing Warm-up estimated alpha

In [None]:
import time
epochs = 40

# Per-epoch timer; restarted after each timing report below.
t0 = time.time()

# Main phase: each epoch ranks the unlabeled training samples, discards the
# top alpha fraction, trains on the remainder, then re-estimates alpha.
if train_method=='CVIR' or train_method=="TEDn":

    print('now, no warm start.')
    alpha_used = alpha_estimate

    #Store losses
    p_loss_list = []
    u_loss_list = []
    epoch_list = []
    total_loss_list = []
    acc_list = []

    alpha_est = []
    for epoch in range(epochs):

        # TEDn follows the running estimate; otherwise the fixed alpha is used.
        if use_alpha:
            alpha_used =  alpha_estimate
        else:
            alpha_used = alpha

        keep_samples = rank_inputs(epoch, net, u_trainloader, device,\
             alpha_used, u_size=train_unl_size)

        train_acc, p_loss, u_loss = train_PU_discard(epoch, net,  p_trainloader, u_trainloader,\
            optimizer, criterion, device, keep_sample=keep_samples,show_bar=show_bar)

        print('train_acc', train_acc)

        total_loss = p_loss + u_loss

        # Stored as a fraction so it shares a scale with the alpha estimates.
        acc_list.append(train_acc/100)

        epoch_list.append(epoch)
        p_loss_list.append(p_loss)
        u_loss_list.append(u_loss)
        total_loss_list.append(total_loss)

        if estimate_alpha:
            pos_probs = p_probs(net, device, p_validloader)
            unlabeled_probs, unlabeled_targets = u_probs(net, device, u_validloader)
            our_mpe_estimate, _, _ = BBE_estimator(pos_probs, unlabeled_probs, unlabeled_targets)

            print('Current estimate after', epoch, 'rounds:', our_mpe_estimate)
            alpha_estimate = our_mpe_estimate
            alpha_est.append(alpha_estimate)
            # The closing parenthesis of round() was misplaced: minutes were
            # rounded to an integer and the 2 went to format() instead.
            print('{} minutes this epoch'.format(round((time.time() - t0)/60, 2)))
            t0 = time.time()

        # Figures are re-drawn after every epoch.
        plot_losses(p_loss_list, u_loss_list, alpha_est, acc_list)


# Define inference-tailored data loader and inference function

In [None]:
import torch
import os
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from tqdm import tqdm
from ast import literal_eval
import numpy as np
from transformers import AutoTokenizer

class ArticleDataset(Dataset):
    """Dataset that tokenizes one ``para_split`` text per item.

    Args:
        dataframe: Frame with ``para_split`` and ``cnki_id`` columns, plus
            a ``label`` column when ``inference`` is false.
        tokenizer: Callable invoked as ``tokenizer(text,
            padding='max_length', truncation=True, max_length=max_len)``
            and returning a mapping of sequences.
        max_len: Padding/truncation length. The original ignored this and
            always used 510 (the value the caller in this file passes).
        get_wids: Unused; kept for interface compatibility.
        inference: When true, no label is read or returned.
    """

    def __init__(self, dataframe, tokenizer, max_len, get_wids, inference):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.inference = inference

    def __getitem__(self, index):
        """Return a dict of tensors for row ``index`` plus its ``cnki_id``."""
        # Positional lookup: a DataLoader asks for 0 .. len - 1, which only
        # matches label-based lookup when the frame has a default index.
        text = self.data.para_split.iloc[index]
        article_id = self.data.cnki_id.iloc[index]

        encoding = self.tokenizer(text,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len)

        item = {key: torch.as_tensor(val) for key, val in encoding.items()}
        if not self.inference:
            item['label'] = torch.as_tensor(self.data.label.iloc[index])

        item['cnki_id'] = article_id
        return item

    def __len__(self):
        return self.len


def inference_data_processing(data_dir, model_name):
    """Build the inference loader for the CSV at ``data_dir``.

    Args:
        data_dir: Path of a CSV with ``cnki_id`` and ``para_split`` columns.
        model_name: Unused; the tokenizer is always the one for
            ``'hfl/chinese-roberta-wwm-ext'``.

    Returns:
        ``(infer_loader, infer_dataset)``: an unshuffled ``DataLoader``
        with batch size 8 and the two-column frame it iterates over.
    """
    tokenizer = getBertTokenizer('hfl/chinese-roberta-wwm-ext')

    frame = pd.read_csv(data_dir)

    # Remember the original row position on the full frame.
    frame['idx'] = frame.index.values

    # Only the id and the text are needed for inference.
    infer_dataset = frame[['cnki_id', 'para_split']]

    # Positional arguments: max_len=510, get_wids=False, inference=True.
    infer_set = ArticleDataset(infer_dataset, tokenizer, 510, False, True)

    loader_options = dict(batch_size=8, shuffle=False, num_workers=2, pin_memory=True)
    infer_loader = DataLoader(infer_set, **loader_options)

    return infer_loader, infer_dataset


def inference(infer_loader, model_type, exp_dir, model_loc, device, threshold):
    """Load a checkpoint and score every batch of ``infer_loader``.

    Args:
        infer_loader: Loader yielding dicts with ``input_ids``,
            ``attention_mask`` and ``cnki_id``.
        model_type: Model selector forwarded to ``get_model``.
        exp_dir: Directory containing the checkpoint. The original ignored
            this argument and read the global ``experiment_dir`` instead.
        model_loc: Checkpoint file name inside ``exp_dir``; the file must
            hold a dict with a ``'model'`` state dict.
        device: Device used for the model and the inputs.
        threshold: A sample is flagged ``True`` when its class-0 softmax
            probability is ``<= threshold``.

    Returns:
        ``(pred_list, cnki_list, logit_list)``, one entry per batch: a
        boolean NumPy array, the batch's ``cnki_id`` values, and a CPU
        tensor of class-0 probabilities.
    """
    net = get_model(model_type)
    # map_location lets a checkpoint saved on GPU be restored on CPU.
    best_checkpoint = torch.load(os.path.join(exp_dir, model_loc), map_location=device)
    net.load_state_dict(best_checkpoint['model'])
    net = net.to(device)
    net.eval()

    pred_list = []
    cnki_list = []
    logit_list = []

    with torch.no_grad():
        for batch in infer_loader:

            ids = batch['input_ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)

            outputs = net(ids, attention_mask=mask, return_dict=False)[0]

            # Despite the list name, these are softmax probabilities for class 0.
            class0_prob = torch.nn.functional.softmax(outputs, dim=-1)[:, 0]
            predicted = (class0_prob <= threshold).cpu().numpy()

            # Moved to CPU so results do not accumulate in device memory.
            logit_list.append(class0_prob.cpu())
            cnki_list.append(batch['cnki_id'])
            pred_list.append(predicted)

    return pred_list, cnki_list, logit_list


#Initialize inference data and model

In [None]:
import os
import argparse
import time
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from transformers import AdamW

# Re-seed all random number generators and release cached GPU memory
# before inference.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
random.seed(42)
torch.cuda.empty_cache()

# NOTE(review): this cell repeats the training configuration with different
# values (lr, alpha, warm_start_epochs, alpha_estimate). The inference call
# that follows in this file only reads `device`, `experiment_dir` and
# `infer_loader` -- confirm the other settings are intentionally redefined.
lr = 5e-6
wd = 5e-4
momentum = 0.9
seed = 42

net_type = "Roberta"
data_type = "TiFa_BERT"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device count', torch.cuda.device_count())
train_method = "TEDn"

alpha = 0.01
beta = 0.15
warm_start = True
warm_start_epochs = 1
batch_size = 8
epochs= 5000
log_dir = './TiFa_BERT'
optimizer_str='AdamW'
alpha_estimate=0.005
show_bar = False
use_alpha = False
estimate_alpha = True
load_model = False
if train_method == "TEDn": 
    use_alpha=True

# Build the unshuffled inference loader over the CSV at data_dir.
infer_loader, infer_df = inference_data_processing(data_dir, model_name='Roberta')

#Perform inference

In [None]:
# Score every batch with the checkpoint 'best_model.pt' at threshold 0.5.
# NOTE(review): train_PU_discard saves checkpoints as
# 'best_model_epoch_{epoch}.pt', so 'best_model.pt' presumably has to be
# created by hand from one of them -- confirm.
pred_list, cnki_list, logit_list = inference(infer_loader, 'Roberta', experiment_dir, 'best_model.pt', device, 0.5)

# Align inference with original dataframe

In [None]:
# Reload the source CSV and attach one prediction per row.
df = pd.read_csv(data_dir)
# Flatten the per-batch lists returned by inference().
cnki_all = [x for y in cnki_list for x in y]
pred_all = [x for y in pred_list for x in y]
#logit_all = [x for y in logits for x in y]

# NOTE(review): cnki_all is not used below; rows are matched to predictions
# purely by position, which relies on the loader being built with
# shuffle=False over the same CSV.
df['inf_lab'] = pred_all
# True means the class-0 probability was <= the threshold; map it to label 0
# and everything else to label 1.
df['inf_lab'] = df.inf_lab.map({True:0, False:1})


In [None]:
# Write the labelled frame locally, then move it onto the mounted Drive
# under a different file name (notebook shell command).
df.to_csv("hu_inference.csv", index = False)
!mv '/content/hu_inference.csv' '/content/gdrive/MyDrive/Capstone/classification/hu_inference_final.csv'