## 2. Clustering

In [1]:
from tqdm import trange, notebook
import pandas as pd
import numpy as np
import random
import warnings
import time
import datetime
import re
import string
import itertools
import pickle
import joblib
import nltk
import csv


import tensorflow as tf
import keras
import keras.backend as K
from keras.preprocessing.text import Tokenizer
from keras.utils import pad_sequences
from keras.layers import Input, Concatenate, Conv2D, Flatten, Dense, Embedding, LSTM
from keras.models import Model
from keras.utils import np_utils

from sklearn.model_selection import train_test_split

In [2]:
base_url = '/content/drive/MyDrive/TAPT/'

In [3]:
task_initial = pd.read_csv(base_url+'dev.csv')['text'].values.tolist()
non_task_initial = pd.read_csv(base_url+'cnn_full.csv')['text'].values.tolist()

task = task_initial
n = int(len(task) / 3)
task_one = task[:n]
task_two = task[n:2*n]
task_three = task[2*n:]
non_task_one = non_task_initial[:n]
non_task_two = non_task_initial[n:2*n]
non_task_three = non_task_initial[2*n:len(task)]

# Creating pairs of data for siamese training => label 1 if pairs from same class otherwise 0
df2 = pd.DataFrame(columns=['text1', 'text2', 'label'])

for idx, data in notebook.tqdm(enumerate(task_one)):
  data1 = data
  data2 = task_two[idx]
  data3 = non_task_one[idx]
  df2.loc[len(df2)] = [data1, data2, 1]
  df2.loc[len(df2)] = [data1, data3, 0]

for idx, data in notebook.tqdm(enumerate(non_task_two)):
  data1 = data
  data2 = non_task_three[idx]
  data3 = task_three[idx]
  df2.loc[len(df2)] = [data1, data2, 1]
  df2.loc[len(df2)] = [data1, data3, 0]


0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [4]:
X_train, X_val, y_train, y_val = train_test_split(df2[['text1', 'text2']], df2['label'], test_size=0.2, random_state=0)
print(X_train.shape, X_val.shape, y_train.shape, y_val.shape)

(6889, 2) (1723, 2) (6889,) (1723,)


In [5]:
X_train['text'] = X_train[['text1', 'text2']].apply(lambda x: str(x[0])+" "+str(x[1]), axis=1)

In [6]:
# load json and create model
json_file = open(base_url+'siamesemodel-contrastive-loss.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
embedding_model = tf.keras.models.model_from_json(loaded_model_json)
# load weights into new model
embedding_model.load_weights(base_url+"siamesemodel-contrastive-loss.h5")
print("Loaded model from disk")

Loaded model from disk


In [7]:
test_df = pd.read_csv(base_url+'AG_test.csv', header=None)
test_df[:3]

Unnamed: 0,0,1
0,Unions representing workers at Turner Newall...,3
1,"SPACE.com - TORONTO, Canada -- A second\team o...",4
2,AP - A company founded by a chemistry research...,4


In [8]:
train_df = pd.read_csv(base_url+'AG_train.csv')
train_df[:3]

Unnamed: 0,text,label
0,About 200 anti-war protesters held a sombre me...,1
1,"The meeting, which was initially due to featur...",1
2,Maybe the long overseas flight will take somet...,2


In [9]:
domain_df = pd.read_csv(base_url+'cnn_full.csv')
domain_df[:3]

Unnamed: 0,text,label
0,"(CNN)Right now, there's a shortage of truck d...",0
1,One solution to the problem is autonomous truc...,0
2,Among them is San Diego-based TuSimple.Founded...,0


In [10]:
t = Tokenizer()
t.fit_on_texts(X_train['text'].values)

def text_to_vector(text):
  vector = t.texts_to_sequences([text])
  vector = pad_sequences(vector,maxlen=200)
  return vector

def get_distance(text1, text2):
  vec1 = text_to_vector(text1)
  vec2 = text_to_vector(text2)
  prediction = embedding_model.predict([vec1, vec2])
  return prediction[0][0].item()

def knn_selection(query_vector, data_vectors, k):
    distances = [get_distance(query_vector, data_vector) for data_vector in data_vectors]
    sorted_indices = np.argsort(distances)
    top_indices = sorted_indices[:k]
    return top_indices

In [11]:
train_1 = train_df[train_df['label'] == 1]['text']
train_2 = train_df[train_df['label'] == 2]['text']
train_3 = train_df[train_df['label'] == 3]['text']
train_4 = train_df[train_df['label'] == 4]['text']
domain_data = domain_df['text']

print(train_1.sample(3))
print(train_2.sample(3))
print(train_3.sample(3))
print(train_4.sample(3))
print(domain_data.sample(3))

36736    ELYRIA, Ohio - President Bush and Sen. John Ke...
11343    The government ordered the Maritime Self-Defen...
84811    The Greek government appealed for calm today, ...
Name: text, dtype: object
124351     AP - Tiger Woods already lost out on the majors.
24880     Center fielder Ellen Estes scored three goals ...
58651     A top LSU official says the school will do wha...
Name: text, dtype: object
50453    MONTREAL : A consumer group in Quebec has sued...
29573    Wal-Mart Stores Inc., the world #39;s largest ...
66617    Net sales for the fiscal Q1 rose 17.1 percent ...
Name: text, dtype: object
146047    NewsFactor - Making illegal copies of computer...
83692     of America (MPAA), along with its audio altern...
81707     CAPE CANAVERAL, Fla. -- NASA #39;s Swift satel...
Name: text, dtype: object
105461    The country is also offering generous subsidie...
125791    Curry's teammate, Draymond Green, who has been...
117894    United goalkeeper David de Gea was then twice ...
Na

In [12]:
'''
1. train dataset에서 {num_samples_train}개의 data sampling
2. 각 train data마다 domain dataset에서 {num_samples_domain}개의 data sampling
3. {num_samples_domain}개의 domain data 중에서 train data와 가장 유사한 top-k를 retrieve하여 train dataset에 추가
'''

k = 25
num_samples_train = 5
num_samples_domain = 1000


In [13]:
for train_data in train_1.sample(num_samples_train):
  top_indices = knn_selection(train_data, domain_data.sample(n=num_samples_domain), k)
  top_texts = domain_data[top_indices]
  print(top_texts)
  for text in top_texts:
    new_record = {'text': text, 'label': 1}
    train_df = train_df.append(new_record, ignore_index=True)

424    "Rio Tinto is in the process of terminating al...
599    The government hiked payments by £20 ($26) a w...
693    That was well above the estimated 3% increase ...
241                                           2 economy.
542    Correction: A previous version of this story m...
993    The company now says its undergoing unrelated ...
467    HospitalityHyatt (H) is halting development in...
597                                         Keep trying.
871    For the most part, they have been used for sho...
728    But the next day, they were told by a nurse mo...
305    A rapid rise in downloads Read MoreDuring the ...
324    The most tech-savvy and privacy-conscious user...
731    "We cannot express what [we] feel [at] that ti...
142       "The name of the game will be differentiation.
767    "Nowhere to goAs international executives jump...
327    Natalia Krapiva, a lawyer at the digital right...
33     That adds another layer of safety for the vehi...
743     Previously, the require

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

438    JPMorgan cited "compliance with directives by ...
247    However, energy veterans cautioned against rea...
447    Restaurant Brands International, which owns th...
491    "The company "is strongly against any acts of ...
115     (CNN)In a Hong Kong warehouse, a swarm of aut...
922    Most automakers rely on radar for adaptive cru...
619    For people who have taken on new debts and fal...
80     Its developer, Bioservo, says it can increase ...
466    The company said it would "assess additional o...
89     Featuring cameras, microphones and sensors, th...
185    ""We are Russian people, thinking and smart, a...
819    Essaye argues that Fed rate hikes and a slowin...
691    Hong Kong (CNN Business)China's economy starte...
938    Then he checked his owner's manual, which call...
282    ""We hope to be back in stock in early April b...
712    "But he said the impact on global commodity pr...
257    'Complicated history'Indian officials said the...
981    The announcement comes a

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

384    In response, the company has also moved to lim...
115     (CNN)In a Hong Kong warehouse, a swarm of aut...
195    "Scott said Fox would provide further updates ...
201    "We are horrified to learn that our fellow cor...
485    "WarnerMedia is also pausing all new business ...
609    "I can't afford £100 ($131) for oil, so why wo...
127    Hide Caption 4 of 8 Photos: The robots running...
733    Her story, and others like it, shine a light o...
421    "We are all deeply troubled by the invasion of...
810    And it inverted in early 2000 right before the...
33     That adds another layer of safety for the vehi...
664    It has since brought back the option and is no...
895    He said Ford had already offered buyers of its...
175    More than 2,500 civilians have died in Mariupo...
657    A third-party company collected the dirty cups...
911    More recently, he's said that self-driving wil...
577    A government spokesperson told CNN Business th...
999    New Delhi (CNN Business)

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

676    Customers can simply give their cups to barist...
145                                        They're not."
625    "We're locking it up into credit cards so we c...
166    "Our new tariffs will further isolate the Russ...
800    These lockdowns come just months after China s...
295    A box of 20 tablets is priced at $12.95 on the...
583    "A lot of people are struggling with debt alre...
565    "You just literally watch [the meter] go down ...
861    But these aren't building blocks, and the cran...
356    "Clegg's internal post on Sunday doubled down ...
873    For these reasons, Piconi says that while batt...
897    But Deep said Ford will be able to install the...
281    Its website currently shows a message saying, ...
400    "Our first priority over the past week has bee...
560    "Or do [you] buy food so they can eat and they...
465    Yum Brands (YUM), which has 1,000 KFC and Pizz...
219    European banks have over $84 billion total cla...
773    And, if unemployed, they

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

692    But as Covid cases in the country spike, keepi...
259    "We recognize that India has a complicated his...
488    "Estée Lauder Companies said March 7 that it w...
94     Designed to keep human workers out of harm's w...
939    Tesla introduced the feature in 2015 and based...
949    It recommends against using adaptive cruise co...
292    He said consumers should follow the CDC's guid...
336    If only a small minority of Russians end up em...
906    They've filed complaints with the National Hig...
566    'There's just nothing left to give'The worst i...
225    New York (CNN Business)US oil prices briefly t...
486    RetailCrocs (CROX) said March 9 that it will "...
923    (Tesla calls its adaptive cruise control, "tra...
150                    The BRICS could become the TICKS.
752    "The representative also pointed to emerging o...
582    "We're appalled when the main measure the gove...
861    But these aren't building blocks, and the cran...
918    "I wanted to make sure I

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

In [14]:
train_df.tail(k*num_samples_train)

Unnamed: 0,text,label
146606,"""Rio Tinto is in the process of terminating al...",1
146607,The government hiked payments by £20 ($26) a w...,1
146608,That was well above the estimated 3% increase ...,1
146609,2 economy.,1
146610,Correction: A previous version of this story m...,1
...,...,...
146726,Telecoms firm ZTE (ZTCOF) lost 7%.,1
146727,"""Clegg's internal post on Sunday doubled down ...",1
146728,officials have declined to say if India would ...,1
146729,"""Put them together and you'll get an A+ on you...",1


In [15]:
for train_data in train_2.sample(num_samples_train):
  top_indices = knn_selection(train_data, domain_data.sample(n=num_samples_domain), k)
  top_texts = domain_data[top_indices]
  print(top_texts)
  for text in top_texts:
    new_record = {'text': text, 'label': 2}
    train_df = train_df.append(new_record, ignore_index=True)

875    During off-peak periods, a turbine pumps water...
764    Most of the senior expats in Hong Kong were on...
526    Companies find themselves caught between helpi...
758    "I suspect there's a lot of international bank...
252    India has not condemned the invasion of Ukrain...
894    "We're doing this as a way to get our customer...
67     The indoor farm is one of the biggest in Europ...
619    For people who have taken on new debts and fal...
223    "For Russia, the main cost is being locked out...
476    General Electric (GE) suspended most of its op...
89     Featuring cameras, microphones and sensors, th...
305    A rapid rise in downloads Read MoreDuring the ...
216    JPMorgan estimates that it had about $40 billi...
292    He said consumers should follow the CDC's guid...
10     Hide Caption 4 of 13 Photos: Reaching speeds o...
353    Russia has separately moved to block Facebook ...
878    "Simple and elegant"Since Energy Vault establi...
823          Rates fall when in

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

664    It has since brought back the option and is no...
120    When an order is sent to the warehouse, the bo...
266    Japanese authorities ordered crypto exchanges ...
863    When power demand is low, the crane uses surpl...
282    ""We hope to be back in stock in early April b...
423    An Exxon subsidiary has a 30% share, while Ros...
460    But PepsiCo will continue to sell some of its ...
44     By focusing on the "middle mile," rather than ...
572    There's no way the numbers add up now," Lucy B...
459    "We are working hard to help keep food availab...
142       "The name of the game will be differentiation.
710    Earlier this month, Premier Li set China's eco...
594    Volunteers at Cooking Champions in London crea...
520    "— Rishi Iyengar, Michelle Toh, Diksha Madhok,...
365     Here's a look at the major corporate departures.
905    Morris is one of hundreds of Tesla owners clai...
416    "PricewaterhouseCoopers (PwC) is also planning...
223    "For Russia, the main co

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

828    "Of course there are many reasons to be concer...
537    "The digital Iron Curtain: How Russia's intern...
943    Previously, Teslas included a basic cruise con...
447    Restaurant Brands International, which owns th...
803    New York (CNN Business)Surging oil and gas pri...
813    "The risks of a recession are building but not...
590    Cooking Champions, an organization which cater...
531    Cogent Communications CEO Dave SchaefferAccord...
203    We wish Ben a quick recovery and call for utmo...
9      It has a top speed of 600 km per hour -- curre...
472    The 26 Hilton hotels in Russia are managed or ...
292    He said consumers should follow the CDC's guid...
992    But Virgin Galactic has yet to follow up that ...
849    "Read More"Addressing a challenge as big as cl...
931    "Cruise control is completely unusable on two-...
664    It has since brought back the option and is no...
397    Sony "joins the global community in calling fo...
505    Procter & Gamble (PG) CE

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

194    "This is a stark reminder for all journalists ...
667    She added that the chain is also experimenting...
705    All businesses — apart from those deemed essen...
794    So far the port continues to operate.Shares of...
887    Credit: Energy Vault\nThis year, Energy Vault ...
398    The streaming service removed all content from...
168     (CNN Business)A woman holding a sign reading ...
71     They're not trying to port an autonomy system ...
174    Satellite images show widespread destruction f...
146    Forget the BRICS and look at TICKS or MIST?To ...
934    The automaker told Morris, according to screen...
225    New York (CNN Business)US oil prices briefly t...
363     (CNN Business)Dozens of the world's biggest c...
231    "You're seeing some vicious selling," said Mat...
633    "But it is also this ubiquitous symbol of a th...
311    But perhaps the fastest-growing messaging app ...
486    RetailCrocs (CROX) said March 9 that it will "...
720    But once she was there, 

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,



  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)


355    The statement added that Meta has "no quarrel ...
113                 "But this is an exciting technology.
702    "The recent spread of the coronavirus in many ...
157    "The emerging markets landscape has been chang...
227    It's the first time oil has traded below $100 ...
546    Begum's energy bills for her one-bedroom apart...
258    officials have declined to say if India would ...
878    "Simple and elegant"Since Energy Vault establi...
338    "The concern, of course, is that the majority ...
863    When power demand is low, the crane uses surpl...
511    FedEx said it suspended operations to "support...
459    "We are working hard to help keep food availab...
665    "We're testing an incentive on the personal cu...
170    Do not believe propaganda they tell you lies h...
284    "The big run started on February 23 through Fe...
612    "With poverty there is that loss of dignity if...
448    That doesn't necessarily mean that Burger King...
221    A second risk is that a 

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

In [16]:
train_df.tail(k*num_samples_train)

Unnamed: 0,text,label
146731,"During off-peak periods, a turbine pumps water...",2
146732,Most of the senior expats in Hong Kong were on...,2
146733,Companies find themselves caught between helpi...,2
146734,"""I suspect there's a lot of international bank...",2
146735,India has not condemned the invasion of Ukrain...,2
...,...,...
146851,"""[They] can be complex if your digital literac...",2
146852,Its most recent foreign currency default came ...,2
146853,"""H&M (HMRZF) will pause all sales in Russia.In...",2
146854,"Tesla CEO Elon Musk said in 2015, 2016, 2017, ...",2


In [17]:
for train_data in train_3.sample(num_samples_train):
  top_indices = knn_selection(train_data, domain_data.sample(n=num_samples_domain), k)
  top_texts = domain_data[top_indices]
  print(top_texts)
  for text in top_texts:
    new_record = {'text': text, 'label': 3}
    train_df = train_df.append(new_record, ignore_index=True)

494    "The clothing giant's stores in Ukraine are al...
385    re:Store is one of the largest Apple resellers...
841    Cook would be the first Black woman to serve a...
818    "Russia/Ukraine is only pulling forward the na...
565    "You just literally watch [the meter] go down ...
912    (Beta is Silicon Valley lingo for an unfinishe...
171    "Russians against war," the last line of the s...
432    In addition, all Visa cards worldwide "will no...
274    The danger is that if radiation exposure occur...
114    "This story has been updated to correct the co...
462    "By continuing to operate, we will also contin...
53     Lu says that 7,000 have been reserved in the U...
722    Her place of work was her home and "my employe...
305    A rapid rise in downloads Read MoreDuring the ...
924    Most autonomous vehicle experts favor sensor f...
51     That means you free up that human driver to do...
656    For baristas, the process was straightforward ...
638               He called tha

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

674        "We've got mock stores set up," said Landers.
688    Employees at a handful of other stores have si...
824    So traders are clearly still finding US Treasu...
664    It has since brought back the option and is no...
375                                          Petersburg.
411    "In light of the escalating war, the EY global...
134                  Watch the video for the full story.
163    It will ban the export of luxury goods to Russ...
443    Citi noted that pulling its operations "will t...
689    "We know that even the most ardent of sustaina...
913    "They're burning up a lot of good will with no...
245    A ceasefire could ease fears about a prolonged...
847    "We believe that there are sufficient grounds ...
659    You don't have to remember to bring your own r...
959    Morris said he would have kept the Model Y if ...
532    The company, which is based in Washington D.C....
75     "In the US every year there are about 5,000 fa...
115     (CNN)In a Hong Kong war

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,



  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)


376    "AviationBoeing (BA) said it would suspend sup...
346    "We're at the beginning of a J-curve," Meinrat...
635    When they are thrown away, the cups end up in ...
841    Cook would be the first Black woman to serve a...
557    Davina Mathurin, project officer for The Boile...
398    The streaming service removed all content from...
969    Rogozin has long been known to share outlandis...
888    It has also signed deals worth up to $880 mill...
963    New York (CNN Business)NASA said Monday that N...
263    "British Foreign Minister Liz Truss also said ...
123    Featuring cameras, microphones and sensors, th...
521     (CNN Business)Big tech platforms have joined ...
244    "Oil traders are also watching for development...
278    Another listing for a box of IOSAT 130 mg pill...
696    China sets lowest economic growth target in de...
785         "Foxconn has two major campuses in Shenzhen.
875    During off-peak periods, a turbine pumps water...
936    It also explained that c

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

161    "The measure will hit Russia's oil majors Rosn...
341    Normalizing censorship-resistant tech  Some di...
82     Exoskeletons are an external device that suppo...
522    In the past week, however, the severing of Rus...
712    "But he said the impact on global commodity pr...
354    Read MoreThe Ukraine-specific policy on hate s...
775    'Zero income'While big international firms may...
484    'The Batman' pulled from RussiaWarnerMedia sai...
122    Hide Caption 2 of 8 Photos: The robots running...
177    JUST WATCHEDAmerican in Moscow reveals how Rus...
395    "Given the current situation, we have no plans...
887    Credit: Energy Vault\nThis year, Energy Vault ...
52     The results of such tests will indicate whethe...
666    "We are also going to be testing a disposable ...
405    "While we know this is the right decision, it ...
765    "Over the last, let's say, 10, 15 years, most ...
377    "Airbus (EADSF) followed Boeing with a similar...
886    Energy Vault's resilienc

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

810    And it inverted in early 2000 right before the...
37     Its plans are dependent on state legislation, ...
730    One worker, not pictured here, said she was no...
964    The space agency sought to reaffirm Monday tha...
663    Early in the pandemic, when people feared that...
482    That includes Marvel's "Doctor Strange in the ...
368    The company has plants in St. Petersburg, Elab...
729    In recent weeks, dozens of domestic workers ha...
228    "This is one hell of a correction," said Tom K...
125    "Handle" is made for the warehouse and equippe...
400    "Our first priority over the past week has bee...
608    She is a member of Covid Realities, a research...
767    "Nowhere to goAs international executives jump...
792    Public transportation, including subways and b...
588    Paychecks can't keep upAt a north London churc...
598    "Moseley receives Universal Credit — a benefit...
112    "If you look at exoskeletons, this is just one...
636    Some might be recycled, 

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

In [18]:
train_df.tail(k*num_samples_train)

Unnamed: 0,text,label
146856,"""The clothing giant's stores in Ukraine are al...",3
146857,re:Store is one of the largest Apple resellers...,3
146858,Cook would be the first Black woman to serve a...,3
146859,"""Russia/Ukraine is only pulling forward the na...",3
146860,"""You just literally watch [the meter] go down ...",3
...,...,...
146976,"""It's crucial for Starbucks' mobile order and ...",3
146977,"""Unauthorized payments to targets under sancti...",3
146978,"Davina Mathurin, project officer for The Boile...",3
146979,"""Any customer we know of who is participating ...",3


In [19]:
for train_data in train_4.sample(num_samples_train):
  top_indices = knn_selection(train_data, domain_data.sample(n=num_samples_domain), k)
  top_texts = domain_data[top_indices]
  print(top_texts)
  for text in top_texts:
    new_record = {'text': text, 'label': 4}
    train_df = train_df.append(new_record, ignore_index=True)

994    The company completed the first-ever all-civil...
411    "In light of the escalating war, the EY global...
744    Hong Kong is sticking to zero-Covid, no matter...
978    "So, we are planning to continue operations as...
456    "McDonald's is temporarily closing its Russian...
171    "Russians against war," the last line of the s...
566    'There's just nothing left to give'The worst i...
144    "It's always strange to say that Argentina and...
797    BYD (BYDDF), China's largest electric car manu...
827    "The recession drumbeat is gaining in volume,"...
89     Featuring cameras, microphones and sensors, th...
289    It also happened in the same year when the Haw...
997    The engines Blue Origin plans to use for New G...
985    After liftoff, the rocket will tear past the s...
521     (CNN Business)Big tech platforms have joined ...
142       "The name of the game will be differentiation.
900    Last year General Motors was forced to tempora...
227    It's the first time oil 

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

496    In addition to pausing its retail and manufact...
716    But Russia's invasion of Ukraine has put their...
497    The company will continue to pay them, at leas...
641    Starbucks is also planning, by the end of next...
108    "By using tools such as the Ironhand we are re...
301    But despite Putin's efforts to clamp down on s...
367    The American automaker has a 50% stake in Ford...
675    "We have different versions of the drive-thru ...
288    The plant's electrical system was reportedly d...
756    BASF (BASFY), a German chemicals giant, recent...
349    The spokesperson affirmed that the restriction...
268    "We decided to make an announcement to keep th...
468    Hyatt said it continues to "evaluate hotel ope...
964    The space agency sought to reaffirm Monday tha...
498    In a statement, the company said Russia accoun...
925    "It's like pairing one A student with another ...
599    The government hiked payments by £20 ($26) a w...
629    New York (CNN Business)S

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

575           Feeling hungry and not asking for a snack.
278    Another listing for a box of IOSAT 130 mg pill...
265    "This is not a fight we have created," he told...
968    But that reliance ended after SpaceX debuted i...
430    "Visa (V) said it is suspending all of its ope...
295    A box of 20 tablets is priced at $12.95 on the...
566    'There's just nothing left to give'The worst i...
193    "The safety of our entire our entire team of j...
979    "Russian space chief says Russia will no longe...
808    The yield curve inverted in 2019 before the 20...
224    But sanctions have done that anyway," wrote an...
628                       "Hardship is harder," he said.
637       "Eliminating the disposable cup," Kobori said.
719    Hong Kong (CNN Business)Maria was just about t...
997    The engines Blue Origin plans to use for New G...
393    Netflix (NFLX) said it will be suspending its ...
562                           Demand was high, she said.
412    "Consulting and accounti

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

265    "This is not a fight we have created," he told...
903    Washington, DC (CNN)Ben Morris bought his Tesl...
750    "From the start of the pandemic through the en...
471       The company does not own any hotels in Russia.
501    Mondelez (MDLZ) said it would scale back all n...
126           It can lift boxes weighing over 30 pounds.
945    "Traffic-Aware Cruise Control may occasionally...
997    The engines Blue Origin plans to use for New G...
516    The group is also suspending all future busine...
189    CNN obtained the video from a live feed of Rus...
409    ""EY in Russia is a team of 4,700 professional...
806    As of Friday, the difference was just 0.25%, w...
204    "Hall's injury comes one day after Brent Renau...
329    But it hasn't been very successful, she said, ...
597                                         Keep trying.
133    Read More"Everyone is looking for an automatio...
615    The charity distributed 48,000 pairs around th...
897    But Deep said Ford will 

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,



  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)


404    ""We will be providing support to our Russian ...
553    Read MoreAverage worker pay suffered its bigge...
543    The company is headquartered in Monroe, Louisi...
481    "The entertainment giant had multiple films se...
857    Shell's net-zero target was also not reflected...
73     Removing the human element on long trucking ro...
814       That could cause a slowdown in the job market.
419    "We are also committed to working with our col...
445    The company said it's also "forfeiting all fin...
781    "We don't know how long these restrictions wil...
172    The outlet's content is tightly controlled by ...
455    "We see a clear distinction between the action...
960                 He still loves his Model X, for now.
193    "The safety of our entire our entire team of j...
433    American Express (AXP) said in a statement tha...
668                                That's simple enough.
367    The American automaker has a 50% stake in Ford...
713    "The recent acceleration

  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record, ignore_index=True)
  train_df = train_df.append(new_record,

In [20]:
train_df.tail(k*num_samples_train)

Unnamed: 0,text,label
146981,The company completed the first-ever all-civil...,4
146982,"""In light of the escalating war, the EY global...",4
146983,"Hong Kong is sticking to zero-Covid, no matter...",4
146984,"""So, we are planning to continue operations as...",4
146985,"""McDonald's is temporarily closing its Russian...",4
...,...,...
147101,But that reliance ended after SpaceX debuted i...,4
147102,AutosFord (F) announced it was suspending its ...,4
147103,"""We have thoroughly evaluated internal and ext...",4
147104,The steel tower is a giant mechanical energy s...,4


# 3. TAPT (AG News classification)

In [21]:
! pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [22]:
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

For augmented train data

In [23]:
input_texts = train_df['text'][-1500:]
labels = train_df['label'][-1500:]

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
encoded_inputs = tokenizer(input_texts.tolist(), padding=True, truncation=True, return_tensors='pt')
labels = torch.tensor(labels.tolist()) - 1

In [24]:
train_inputs, val_inputs, train_labels, val_labels = train_test_split(encoded_inputs['input_ids'],
                                                                    labels,
                                                                    random_state=42,
                                                                    test_size=0.2)
train_dataset = TensorDataset(train_inputs, train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [None]:
model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=4)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
loss_fn = torch.nn.CrossEntropyLoss()

epochs = 10
for epoch in range(epochs):
    total_loss = 0
    model.train()

    for batch in train_dataloader:
        batch_inputs, batch_labels = batch
        outputs = model(input_ids=batch_inputs, labels=batch_labels)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        print('.', end='')

    model.eval()

    with torch.no_grad():
        val_outputs = model(input_ids=val_inputs, labels=val_labels)
        val_loss = val_outputs.loss
        val_accuracy = (val_outputs.logits.argmax(dim=1) == val_labels).float().mean()

    print(f'Epoch {epoch + 1}/{epochs}')
    print(f'Training loss: {total_loss / len(train_dataloader)}')
    print(f'Validation loss: {val_loss}')
    print(f'Validation accuracy: {val_accuracy}')

model.save_pretrained('roberta_classification_model')


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.weight', 'classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.bias']
You should pr

In [None]:
test_texts = test_df[0][:2000]
test_labels = test_df[1][:2000]

encoded_test_inputs = tokenizer(test_texts.tolist(), padding=True, truncation=True, return_tensors='pt')
test_labels = torch.tensor(test_labels.tolist()) - 1

test_dataset = TensorDataset(encoded_test_inputs['input_ids'], test_labels)
test_dataloader = DataLoader(test_dataset, batch_size=50, shuffle=False)

In [None]:
from sklearn.metrics import classification_report

predictions = []
labels = []

with torch.no_grad():
    for batch in test_dataloader:
        inputs = {
            'input_ids': batch[0]
        }
        outputs = model(**inputs)
        logits = outputs.logits
        predictions.extend(torch.argmax(logits, dim=1).tolist())
        labels.extend(batch[1].tolist())

classification_rep = classification_report(labels, predictions, target_names=['class1', 'class2', 'class3', 'class4'])
print(classification_rep)

For non-augmented train data

In [13]:
input_texts = train_df['text'][-1000:]
labels = train_df['label'][-1000:]

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
encoded_inputs = tokenizer(input_texts.tolist(), padding=True, truncation=True, return_tensors='pt')
labels = torch.tensor(labels.tolist()) - 1

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

In [14]:
train_inputs, val_inputs, train_labels, val_labels = train_test_split(encoded_inputs['input_ids'],
                                                                    labels,
                                                                    random_state=42,
                                                                    test_size=0.2)
train_dataset = TensorDataset(train_inputs, train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [15]:
model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=4)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
loss_fn = torch.nn.CrossEntropyLoss()

epochs = 10
for epoch in range(epochs):
    total_loss = 0
    model.train()

    for batch in train_dataloader:
        batch_inputs, batch_labels = batch
        outputs = model(input_ids=batch_inputs, labels=batch_labels)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        print('.', end='')

    model.eval()

    with torch.no_grad():
        val_outputs = model(input_ids=val_inputs, labels=val_labels)
        val_loss = val_outputs.loss
        val_accuracy = (val_outputs.logits.argmax(dim=1) == val_labels).float().mean()

    print(f'Epoch {epoch + 1}/{epochs}')
    print(f'Training loss: {total_loss / len(train_dataloader)}')
    print(f'Validation loss: {val_loss}')
    print(f'Validation accuracy: {val_accuracy}')

model.save_pretrained('roberta_classification_model')


Downloading model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight']
You should pr

.........................Epoch 1/10
Training loss: 1.3443658256530762
Validation loss: 1.1231213808059692
Validation accuracy: 0.5950000286102295
.........................Epoch 2/10
Training loss: 0.8393831539154053
Validation loss: 0.5801771879196167
Validation accuracy: 0.7900000214576721
.........................Epoch 3/10
Training loss: 0.5553411138057709
Validation loss: 0.5203447341918945
Validation accuracy: 0.8299999833106995
.........................Epoch 4/10
Training loss: 0.4315053939819336
Validation loss: 0.49683961272239685
Validation accuracy: 0.824999988079071
.........................Epoch 5/10
Training loss: 0.3395357257127762
Validation loss: 0.4552464187145233
Validation accuracy: 0.8450000286102295
.........................Epoch 6/10
Training loss: 0.27879302829504016
Validation loss: 0.5199941396713257
Validation accuracy: 0.8399999737739563
.........................Epoch 7/10
Training loss: 0.23258209884166717
Validation loss: 0.4807156026363373
Validation accur

In [17]:
test_texts = test_df[0][:2000]
test_labels = test_df[1][:2000]

encoded_test_inputs = tokenizer(test_texts.tolist(), padding=True, truncation=True, return_tensors='pt')
test_labels = torch.tensor(test_labels.tolist()) - 1

test_dataset = TensorDataset(encoded_test_inputs['input_ids'], test_labels)
test_dataloader = DataLoader(test_dataset, batch_size=50, shuffle=False)

In [23]:
from sklearn.metrics import classification_report

predictions = []
labels = []

with torch.no_grad():
    for batch in test_dataloader:
        inputs = {
            'input_ids': batch[0]
        }
        outputs = model(**inputs)
        logits = outputs.logits
        predictions.extend(torch.argmax(logits, dim=1).tolist())
        labels.extend(batch[1].tolist())

classification_rep = classification_report(labels, predictions, target_names=['class1', 'class2', 'class3', 'class4'])
print(classification_rep)

              precision    recall  f1-score   support

      class1       0.84      0.79      0.81       470
      class2       0.92      0.85      0.89       604
      class3       0.74      0.77      0.75       405
      class4       0.71      0.79      0.75       521

    accuracy                           0.80      2000
   macro avg       0.80      0.80      0.80      2000
weighted avg       0.81      0.80      0.81      2000

