In [1]:
import argparse
import sys
from sentence_transformers import SentenceTransformer
from sentence_transformers_local import models, losses, SentenceTransformerSequential
from models.Transformers import SCCLBert
from learners.cluster import ClusterLearner

from dataloader.dataloader import augment_loader, augment_loader_split
from training import training

from utils.kmeans import get_kmeans_centers
# from utils.logger import setup_path
from utils.randomness import set_global_random_seed
import torch
import pandas as pd
import os
from torch import nn

In [2]:
# Short aliases -> Sentence-Transformers checkpoint names. `--bert` picks one of
# these keys; the selected checkpoint is what the training cell uses to embed the
# data for the k-means initialisation of the cluster centres.
MODEL_CLASS = {
    "distil": 'distilbert-base-nli-stsb-mean-tokens', 
    "robertabase": 'roberta-base-nli-stsb-mean-tokens',
    "robertalarge": 'roberta-large-nli-stsb-mean-tokens',
    "msmarco": 'distilroberta-base-msmarco-v2',
    "xlm": "xlm-r-distilroberta-base-paraphrase-v1",
    "bertlarge": 'bert-large-nli-stsb-mean-tokens',
    "bertbase": 'bert-base-nli-stsb-mean-tokens',
    "paraphrase": "paraphrase-mpnet-base-v2",
    "paraphrase-distil": "paraphrase-distilroberta-base-v2",
    "paraphrase-Tiny" : "paraphrase-TinyBERT-L6-v2"
}


def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    (every non-empty string is truthy), so ``--use_head False`` would silently
    *enable* the flag. This parser maps the usual spellings explicitly.

    Args:
        value: the raw option string (or an already-parsed bool).

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


class _ScalarOrListAction(argparse.Action):
    """Store one ``nargs='+'`` value as a scalar and several values as a list.

    This keeps ``--train_val_ratio -1`` producing the scalar ``-1.0``, which is
    what the ``args.train_val_ratio != -1`` check in the training cell tests for.
    It also lets the default ``[train, val]`` list be given on the command line
    (``--train_val_ratio 0.9 0.1``).
    """

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values[0] if len(values) == 1 else list(values))


parser = argparse.ArgumentParser()
# parser.add_argument('--gpuid', nargs="+", type=int, default=[0], help="The list of gpuid, ex:--gpuid 3 1. Negative value means cpu-only")
parser.add_argument('--seed', type=int, default=0, help="")
parser.add_argument('--print_freq', type=float, default=100, help="")  
parser.add_argument('--result_path', type=str, default='./results/')

parser.add_argument('--bert', type=str, default='paraphrase', help="")  # key into MODEL_CLASS
#parser.add_argument('--bert', type=str, default='distil', help="")

parser.add_argument('--bert_model', type=str, default='bert-base-uncased', help="")
parser.add_argument('--note', type=str, default='_search_snippets_distil_lre-4_JSD', help="")

# Dataset
# NOTE(review): `--dataset`/`--note` say search_snippets and `--data_path` points
# at an augmented-data folder, but the active `--dataname`/`--num_classes` below
# are the AgNews preset (agnewsdataraw-8000, 4 classes). Check that these presets
# are combined as intended before comparing runs.
# stackoverflow/stackoverflow_true_text
parser.add_argument('--dataset', type=str, default='search_snippets', help="")
#parser.add_argument('--dataset', type=str, default='stackoverflow', help="")
# parser.add_argument('--data_path', type=str, default='./datasets/stackoverflow/')
parser.add_argument('--max_length', type=int, default=32)
# Either a [train, val] pair (default) or -1 to train on the full data without
# a validation split. A single CLI value is stored as a scalar (see _ScalarOrListAction).
parser.add_argument('--train_val_ratio', type=float, nargs='+', action=_ScalarOrListAction,
                    default=[0.9, 0.1])

# Data for train and test
# ###### AgNews
# parser.add_argument('--data_path', type=str, default='./datasets/')
parser.add_argument('--dataname', type=str, default='agnewsdataraw-8000', help="")
parser.add_argument('--dataname_val', type=str, default='agnewsdataraw-8000', help="")
parser.add_argument('--num_classes', type=int, default=4, help="")
# ####### SearchSnippets
parser.add_argument('--data_path', type=str, default='./datasets/augmented/contextual_20_2col_bert/')
# ## parser.add_argument('--dataname', type=str, default='train_search_snippets.csv', help="")
## parser.add_argument('--dataname_val', type=str, default='test_search_snippets.csv', help="")
# parser.add_argument('--dataname', type=str, default='search_snippets', help="")
# parser.add_argument('--dataname_val', type=str, default='search_snippets', help="")
# parser.add_argument('--num_classes', type=int, default=8, help="")
# # ###### StackOverFlow
# parser.add_argument('--data_path', type=str, default='./datasets/stackoverflow/')
# parser.add_argument('--dataname', type=str, default='stackoverflow', help="")
# parser.add_argument('--dataname_val', type=str, default='stackoverflow_', help="")
# parser.add_argument('--num_classes', type=int, default=20, help="")
# ###### Biomedical
# # parser.add_argument('--data_path', type=str, default='./datasets/biomedical/')
# parser.add_argument('--dataname', type=str, default='biomedical', help="")
# parser.add_argument('--dataname_val', type=str, default='biomedical', help="")
# parser.add_argument('--num_classes', type=int, default=20, help="")
# ######## Tweet
# parser.add_argument('--data_path', type=str, default='./datasets/')
# parser.add_argument('--dataname', type=str, default='tweet_remap_label', help="")
# parser.add_argument('--dataname_val', type=str, default='tweet_remap_label', help="")
# parser.add_argument('--num_classes', type=int, default=89, help="")
# ######## GoogleNewsTS
# parser.add_argument('--data_path', type=str, default='./datasets/')
# parser.add_argument('--dataname', type=str, default='TS', help="")
# parser.add_argument('--dataname_val', type=str, default='TS', help="")
# parser.add_argument('--num_classes', type=int, default=152, help="")
# ######## GoogleNewsT
# parser.add_argument('--data_path', type=str, default='./datasets/')
# parser.add_argument('--dataname', type=str, default='T', help="")
# parser.add_argument('--dataname_val', type=str, default='T', help="")
# parser.add_argument('--num_classes', type=int, default=152, help="")
# ######## GoogleNewsS
# parser.add_argument('--data_path', type=str, default='./datasets/')
# parser.add_argument('--dataname', type=str, default='S', help="")
# parser.add_argument('--dataname_val', type=str, default='S', help="")
# parser.add_argument('--num_classes', type=int, default=152, help="")

# Learning parameters
parser.add_argument('--lr', type=float, default=1e-6, help="")  # base learning rate
parser.add_argument('--lr_scale', type=int, default=100, help="")
parser.add_argument('--max_iter', type=int, default=3000)
parser.add_argument('--batch_size', type=int, default=256)

# CNN Setting
#parser.add_argument('--out_channels', type=int, default=768)
#parser.add_argument('--use_cnn', type=str, default='cnn_1')
#parser.add_argument('--use_cnn', type=str, default='cnn_3')
#parser.add_argument('--use_cnn', type=str, default='cnn_5')
#parser.add_argument('--use_cnn', type=str, default='cnn_7')
#parser.add_argument('--use_cnn', type=str, default='cnn_cat')
#parser.add_argument('--use_cnn', type=str, default='cnn_avg')

# Contrastive learning
# Boolean flags use str2bool: with type=bool, "--use_head False" would be True.
parser.add_argument('--use_head', type=str2bool, default=False)
parser.add_argument('--use_normalize', type=str2bool, default=False)

parser.add_argument('--weighted_local', type=str2bool, default=False, help="")
#parser.add_argument('--normalize_method', type=str, default='inverse_prob', help="")
parser.add_argument('--normalize_method', type=str, default='none', help="")

# These two scales are passed to ClusterLearner in the training cell; the author
# marked them as unused inside it — TODO confirm before tuning them.
parser.add_argument('--contrastive_local_scale', type=float, default=0.002)
parser.add_argument('--contrastive_global_scale', type=float, default=0.008)
parser.add_argument('--temperature', type=float, default=0.5, help="temperature required by contrastive loss")
parser.add_argument('--base_temperature', type=float, default=0.1, help="temperature required by contrastive loss")

# Clustering
parser.add_argument('--clustering_scale', type=float, default=0.02)  # scale of clustering loss
parser.add_argument('--use_perturbation', action='store_true', help="")
parser.add_argument('--alpha', type=float, default=1)

# Parse an empty argv so the notebook always runs with the defaults above
# (inside Jupyter, sys.argv holds the kernel's own arguments).
args = parser.parse_args(args=[])
# args.use_gpu = args.gpuid[0] >= 0
args.resPath = None
args.tensorboard = None

In [3]:
# !pip install sentence-transformers
# !pip install torch
# !pip install captum
# !pip install protobuf==3.19.6
# !pip install tensorboardX
# !pip install pandas

In [4]:
# !pip install torch
# Restrict this kernel to one physical GPU. PCI_BUS_ID ordering makes the
# index follow the PCI bus order that `nvidia-smi` also uses (see issue #152).
# NOTE(review): these variables are only honoured if CUDA has not been
# initialised yet in this process. torch is already imported in the first cell,
# so setting them before `import torch` would be more robust.
import os
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": '4'})

# Prefer the (single visible) GPU; fall back to the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print(torch.cuda.device_count())

# Extra diagnostics that only make sense on a CUDA device.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    for label, usage_of in (('Allocated:', torch.cuda.memory_allocated),
                            ('Cached:   ', torch.cuda.memory_reserved)):
        # Bytes -> GiB, rounded to one decimal place.
        print(label, round(usage_of(0) / 1024 ** 3, 1), 'GB')

Using device: cuda
1
Tesla V100-SXM2-32GB-LS
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [5]:
import timeit
from datetime import datetime, timezone

start = timeit.default_timer()
now_utc = datetime.now(timezone.utc)
print('Time UTC:', now_utc)

# resPath, tensorboard = setup_path(args)
# args.resPath, args.tensorboard = resPath, tensorboard

# Seed all RNGs before data loading and model construction, so the k-means
# initialisation and training are repeatable for a given --seed.
set_global_random_seed(args.seed)

# Loader over the whole dataset; it is used only for the k-means initialisation below.
# NOTE(review): the recorded output shows k-means running on all 8000 samples.
# That includes the 10% that augment_loader_split later holds out for
# validation. Only embeddings feed k-means, but the validation split still
# shapes the initial centres — TODO confirm this transductive setup is intended.
train_loader = augment_loader(args)

# torch.cuda.set_device(args.gpuid[0])
# torch.cuda.set_device(device)

# --- Initial cluster centres ---------------------------------------------------
# k-means on Sentence-BERT embeddings (mean pooling by default) of the data.
sbert_checkpoint = MODEL_CLASS[args.bert]
sbert = SentenceTransformer(sbert_checkpoint)
cluster_centers = get_kmeans_centers(sbert, train_loader, args.num_classes)

# --- Model ----------------------------------------------------------------------
# 1. Transformer encoder (Hugging Face model) mapping tokens to embeddings.
# Fine-tune the same checkpoint that produced the k-means centres. A hard-coded
# name here only matches the centres for one --bert choice and silently diverges
# for the others. For the default --bert=paraphrase this resolves to
# 'sentence-transformers/paraphrase-mpnet-base-v2'.
# NOTE(review): assumes every MODEL_CLASS entry is published under the
# 'sentence-transformers/' Hugging Face namespace — TODO confirm for 'msmarco'.
hf_model_name = (sbert_checkpoint if '/' in sbert_checkpoint
                 else f'sentence-transformers/{sbert_checkpoint}')
word_embedding_model = models.Transformer(hf_model_name)
# word_embedding_model = models.Transformer('sentence-transformers/stanford-sentiment-treebank-roberta.2021-03-11')

dimension = word_embedding_model.get_word_embedding_dimension()
# word_embedding_model = torch.nn.DataParallel(word_embedding_model)

# 2. CNN model (disabled)
# cnn = models.CNN(in_word_embedding_dimension = word_embedding_model.get_word_embedding_dimension(), 
#                  use_cnn = args.use_cnn, out_channels = word_embedding_model.get_word_embedding_dimension())

# 3. Pooling: mean over token embeddings only (CLS / max / weighted modes off).
# pooling_model = models.Pooling(cnn.get_word_embedding_dimension(),
#                                pooling_mode_mean_tokens=True,
#                                pooling_mode_cls_token=False,
#                                pooling_mode_max_tokens=False)
pooling_model = models.Pooling(dimension,
                               pooling_mode_mean_tokens=True,
                               pooling_mode_cls_token=False,
                               pooling_mode_max_tokens=False,
                               pooling_mode_weighted_tokens=False)

# 4. Feature extractor = encoder followed by pooling. It is placed on the device
# chosen in the device-setup cell; str(device) is 'cuda' on a GPU machine,
# exactly as before.
#feature_extractor = SentenceTransformerSequential(modules=[word_embedding_model, cnn, pooling_model])
feature_extractor = SentenceTransformerSequential(modules=[word_embedding_model, pooling_model],
                                                  device=str(device))

# 5. Main model: feature extractor plus the k-means cluster centres.
model = SCCLBert(feature_extractor, cluster_centers=cluster_centers, alpha=args.alpha, use_head=args.use_head)

# --- Optimizer --------------------------------------------------------------------
# Per-group learning rates: encoder 6x the base lr, cluster centres 60x, and the
# pooling layer uses the base lr. The multipliers are hard-coded; args.lr_scale
# is only referenced by the disabled head group.
optimizer = torch.optim.Adam([
    {'params': word_embedding_model.parameters(), 'lr': args.lr*6},
#    {'params':cnn.parameters(), 'lr': args.lr*50},
    {'params': pooling_model.parameters()},
#    {'params':model.head.parameters(), 'lr': args.lr*args.lr_scale},
    {'params': model.cluster_centers, 'lr': args.lr*60}], lr=args.lr)
# # Alternative that was tried:
# optimizer = torch.optim.AdamW([
#     {'params':word_embedding_model.parameters(), 'lr': args.lr},
# #    {'params':cnn.parameters(), 'lr': args.lr*50},
#     {'params':pooling_model.parameters()},
# #    {'params':model.head.parameters(), 'lr': args.lr*args.lr_scale},
#     {'params':model.cluster_centers, 'lr': args.lr*20}], lr=args.lr)
print(optimizer)

# --- Trainer ----------------------------------------------------------------------
learner = ClusterLearner(model, feature_extractor, optimizer, args.temperature, args.base_temperature,
                         args.contrastive_local_scale, args.contrastive_global_scale, args.clustering_scale,
                         use_head=args.use_head, use_normalize=args.use_normalize)
# learner = torch.nn.DataParallel(learner)
# Same device as the feature extractor instead of an unconditional .cuda().
learner = learner.to(device)

# --- Training ---------------------------------------------------------------------
# train_val_ratio == -1 trains on the full loader. Otherwise the train/val loaders
# are rebuilt from the split, and the validation loader is handed to training.
if args.train_val_ratio != -1:
    train_loader, val_loader = augment_loader_split(args)
    training(train_loader, learner, args, val_loader=val_loader)
else:
    training(train_loader, learner, args)

Time UTC: 2023-12-13 14:17:19.925986+00:00
all_embeddings:(8000, 768), true_labels:8000, pred_labels:8000
true_labels tensor([1, 2, 0,  ..., 1, 3, 1])
pred_labels tensor([0, 3, 2,  ..., 0, 1, 0], dtype=torch.int32)
Iterations:69, Clustering ACC:0.835, centers:(4, 768)
initial_cluster_centers =  torch.Size([4, 768])
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 6e-06
    weight_decay: 0

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 1e-06
    weight_decay: 0

Parameter Group 2
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 5.9999999999999995e-05
    weight_decay: 0
)




train_sample 0.9 7200
val_sample 0.1 800

=3000/29=Iterations/Batches
------------- Evaluate Training Set -------------
------------- 29 batches -------------
all_pred 4
step: 0
[Representation] Clustering scores: {'NMI': 0.5775709856489576, 'ARI': 0.619267308579574, 'AMI': 0.5773798619634802}
[Representation] ACC: 0.8344
[Representation] ACC sklearn: 0.4003
[Model] Clustering scores: {'NMI': 0.5791156560696884, 'ARI': 0.6189847484637262, 'AMI': 0.578925190528274}
[Model] ACC: 0.8335
[Model] ACC sklearn: 0.0518
------------- Evaluate Validation Set -------------
------------- 4 batches -------------
all_pred 4
step: 0
[Representation] Clustering scores: {'NMI': 0.5708731478133698, 'ARI': 0.5903315203525511, 'AMI': 0.5691091356786355}
[Representation] ACC: 0.8200
[Representation] ACC sklearn: 0.2100
[Model] Clustering scores: {'NMI': 0.6024378985953256, 'ARI': 0.6398325713091625, 'AMI': 0.6008022443171119}
[Model] ACC: 0.8450
[Model] ACC sklearn: 0.0475
Time UTC: 2023-12-13 14:18:11.949

all_pred 4
step: 800
[Representation] Clustering scores: {'NMI': 0.6979692943936944, 'ARI': 0.7217461869513467, 'AMI': 0.6967303673388124}
[Representation] ACC: 0.8838
[Representation] ACC sklearn: 0.0250
[Model] Clustering scores: {'NMI': 0.6368480710834659, 'ARI': 0.6675972772886306, 'AMI': 0.6353511438908235}
[Model] ACC: 0.8550
[Model] ACC sklearn: 0.0475
Time UTC: 2023-12-13 14:50:17.811835+00:00
Current running time 1978.16 seconds
------------- Evaluate Training Set -------------
------------- 29 batches -------------
all_pred 4
step: 900
[Representation] Clustering scores: {'NMI': 0.6778519654722981, 'ARI': 0.7179895388372398, 'AMI': 0.6777065643819999}
[Representation] ACC: 0.8811
[Representation] ACC sklearn: 0.4701
[Model] Clustering scores: {'NMI': 0.6694998768431858, 'ARI': 0.7105143704731012, 'AMI': 0.6693506789642608}
[Model] ACC: 0.8775
[Model] ACC sklearn: 0.0428
------------- Evaluate Validation Set -------------
------------- 4 batches -------------
all_pred 4
step: 

------------- Evaluate Training Set -------------
------------- 29 batches -------------
all_pred 4
step: 1700
[Representation] Clustering scores: {'NMI': 0.7057040082048721, 'ARI': 0.7479968418970042, 'AMI': 0.7055711693840683}
[Representation] ACC: 0.8957
[Representation] ACC sklearn: 0.2419
[Model] Clustering scores: {'NMI': 0.7057040082048721, 'ARI': 0.7479968418970042, 'AMI': 0.7055711693840683}
[Model] ACC: 0.8957
[Model] ACC sklearn: 0.0372
------------- Evaluate Validation Set -------------
------------- 4 batches -------------
all_pred 4
step: 1700
[Representation] Clustering scores: {'NMI': 0.7203940678265698, 'ARI': 0.7645527329545073, 'AMI': 0.7192481319782431}
[Representation] ACC: 0.9038
[Representation] ACC sklearn: 0.4350
[Model] Clustering scores: {'NMI': 0.7203940678265698, 'ARI': 0.7645527329545073, 'AMI': 0.7192481319782431}
[Model] ACC: 0.9038
[Model] ACC sklearn: 0.0325
Time UTC: 2023-12-13 15:26:31.254433+00:00
Current running time 4151.6 seconds
------------- Ev

all_pred 4
step: 2500
[Representation] Clustering scores: {'NMI': 0.7275570545440082, 'ARI': 0.7682634609533558, 'AMI': 0.7264405210638069}
[Representation] ACC: 0.9050
[Representation] ACC sklearn: 0.4338
[Model] Clustering scores: {'NMI': 0.7275570545440082, 'ARI': 0.7682634609533558, 'AMI': 0.7264405210638069}
[Model] ACC: 0.9050
[Model] ACC sklearn: 0.0288
Time UTC: 2023-12-13 15:58:06.795809+00:00
Current running time 6047.14 seconds
------------- Evaluate Training Set -------------
------------- 29 batches -------------
all_pred 4
step: 2600
[Representation] Clustering scores: {'NMI': 0.705656383638045, 'ARI': 0.7472190397214792, 'AMI': 0.7055235176112404}
[Representation] ACC: 0.8953
[Representation] ACC sklearn: 0.5125
[Model] Clustering scores: {'NMI': 0.705656383638045, 'ARI': 0.7472190397214792, 'AMI': 0.7055235176112404}
[Model] ACC: 0.8953
[Model] ACC sklearn: 0.0371
------------- Evaluate Validation Set -------------
------------- 4 batches -------------
all_pred 4
step: 