In [None]:
!pip install datasets

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import argparse
import sys
import numpy as np
import datasets
from datasets import load_dataset, load_metric, DatasetDict
import random
import json
import time
import datetime
from tqdm import tqdm
import os

import torch
import torch.cuda
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DataParallel

from transformers import (
    AutoTokenizer,
    EncoderDecoderModel,
    BertGenerationDecoder,
    BertGenerationEncoder,
    BertTokenizer,
    BertModel)


# data

In [None]:
def get_dataset(data_files):
    """Load a JSON dataset and hand back its train and test splits.

    Args:
        data_files: Forwarded unchanged to ``load_dataset("json", data_files=...)``.
            The loaded object is indexed with ``'train'`` and ``'test'``, so
            both splits have to be present.

    Returns:
        tuple: ``(train_dataset, test_dataset)``.
    """
    splits = load_dataset("json", data_files=data_files)
    return splits['train'], splits['test']

def get_samples(data_files, len_train_sample=100, len_test_sample=100, seed=None):
    """Draw a shuffled subset from the train and test splits of a JSON dataset.

    Args:
        data_files: Forwarded to ``get_dataset``.
        len_train_sample: Maximum number of training rows to keep.
        len_test_sample: Maximum number of test rows to keep.
        seed: Optional seed handed to ``shuffle`` so the subset can be
            reproduced. ``None`` keeps the previous unseeded behaviour.

    Returns:
        tuple: ``(train_dataset_sample, test_dataset_sample)``.
    """
    train_dataset, test_dataset = get_dataset(data_files)

    # Clamp to the split size (same policy as get_random): asking for more rows
    # than exist used to make ``select`` fail on an out-of-range index.
    n_train = min(len_train_sample, len(train_dataset))
    n_test = min(len_test_sample, len(test_dataset))

    train_dataset_sample = train_dataset.shuffle(seed=seed).select(range(n_train))
    test_dataset_sample = test_dataset.shuffle(seed=seed).select(range(n_test))

    return train_dataset_sample, test_dataset_sample

def get_csv_dataset(data_file):
    """Load a CSV dataset.

    Args:
        data_file: Forwarded as ``data_files`` to ``load_dataset("csv", ...)``.

    Returns:
        Whatever ``load_dataset`` returns for the given file(s).
    """
    return load_dataset("csv", data_files=data_file)


def get_random(dataset, num_per_label=200):
    """Randomly pick up to ``num_per_label`` sentences for every label.

    Args:
        dataset: Iterable of mappings that each provide a ``'label'`` and a
            ``'sentence1'`` entry.
        num_per_label: Upper bound on the number of sentences kept per label
            (default 200). Labels with fewer sentences keep all of them.

    Returns:
        dict: ``label -> list of sampled sentences`` (sampled without
        replacement via ``random.sample``), in first-seen label order.
    """
    # Bucket the sentences by label, preserving first-seen label order.
    grouped = {}
    for data in dataset:
        grouped.setdefault(data['label'], []).append(data['sentence1'])

    return {
        label: random.sample(sentences, k=min(num_per_label, len(sentences)))
        for label, sentences in grouped.items()
    }

# dataset

In [None]:
import spacy
nlp = spacy.load("en_core_web_sm")

class BertDataset(Dataset):
    """Map-style dataset that tokenizes raw text on first access and caches it.

    Each item is ``(input_ids, token_type_ids, attention_mask, label)``; when
    ``return_idx`` is true the integer index and the raw text are appended.

    Args:
        x: Indexable collection of raw text samples.
        y: Labels aligned with ``x``; converted with ``torch.tensor``.
        tokenizer: Object exposing ``batch_encode_plus`` whose result can be
            indexed by ``'input_ids'``, ``'token_type_ids'`` and
            ``'attention_mask'``.
        length: ``max_length`` used for padding/truncation.
        return_idx: Whether ``__getitem__`` also returns the index and text.
    """

    def __init__(self, x, y, tokenizer, length=128, return_idx=False):
        super(BertDataset, self).__init__()
        self.tokenizer = tokenizer
        self.length = length
        self.x = x
        self.return_idx = return_idx
        self.y = torch.tensor(y)
        self.tokens_cache = {}  # idx -> [input_ids, token_type_ids, attention_mask]

    def mask_nouns(self, text):
        """Return ``text`` with every token tagged ``NOUN`` replaced by ``[MASK]``.

        Uses the module-level spaCy pipeline ``nlp``; tokens are re-joined with
        single spaces.
        """
        doc = nlp(text)
        return ' '.join('[MASK]' if token.pos_ == 'NOUN' else token.text for token in doc)

    def tokenize(self, x):
        """Tokenize one text and return ``[input_ids, token_type_ids, attention_mask]``.

        The fields are looked up by name rather than unpacked from
        ``dic.values()``, so the result no longer depends on the key order of
        the tokenizer output.
        """
        dic = self.tokenizer.batch_encode_plus(
            [x],  # input must be a list
            max_length=self.length,
            padding='max_length',
            truncation=True,
            return_token_type_ids=True,
            return_tensors="pt"
        )
        # [0] drops the batch dimension introduced by wrapping x in a list.
        return [dic[key][0] for key in ('input_ids', 'token_type_ids', 'attention_mask')]

    def __getitem__(self, idx):
        int_idx = int(idx)
        assert idx == int_idx  # reject non-integral indices
        idx = int_idx
        if idx not in self.tokens_cache:
            self.tokens_cache[idx] = self.tokenize(self.x[idx])
        input_ids, token_type_ids, attention_mask = self.tokens_cache[idx]
        if self.return_idx:
            return input_ids, token_type_ids, attention_mask, self.y[idx], idx, self.x[idx]
        return input_ids, token_type_ids, attention_mask, self.y[idx]

    def __len__(self):
        return len(self.y)


from torch.utils.data import DataLoader, Dataset, Sampler
class TrainSampler(Sampler):
    """Sampler that lays indices out so each batch has a block of same-class samples.

    Every batch starts from ``num_pos_samples = int(batch_size * sim_ratio)``
    indices of one randomly chosen class ("positives") and is filled up with
    indices drawn one at a time from the other classes ("negatives"). The
    batches are shuffled and flattened, so the sampler is meant to be used with
    a ``DataLoader`` whose ``batch_size`` matches and that does not reshuffle.

    Args:
        dataset: Object exposing ``x`` and ``y``; ``y`` must support
            ``tolist()`` (e.g. a 1-D tensor of integer labels).
        batch_size: Number of indices per batch.
        sim_ratio: Fraction of each batch taken from the positive class.

    Note:
        ``__len__`` reports the size of one pass measured at construction time.
        Because class choice is random, a later pass may yield a different
        number of indices.
    """

    def __init__(self, dataset, batch_size, sim_ratio=0.5):
        super().__init__(None)
        self.dataset = dataset
        self.batch_size = batch_size
        self.x = dataset.x
        self.y = dataset.y
        self.sim_ratio = sim_ratio
        self.num_pos_samples = int(batch_size * sim_ratio)
        print(f'train sampler with batch size = {batch_size} and postive sample ratio = {sim_ratio}')

        self.length = len(list(self.__iter__()))

    def __iter__(self):
        # Group sample indices by label, then shuffle inside every class.
        label_cluster = {}
        for i, label in enumerate(self.y.tolist()):
            label_cluster.setdefault(label, []).append(i)
        for value in label_cluster.values():
            random.shuffle(value)

        # Candidate classes are the labels actually present, in ascending order.
        # For contiguous 0-based labels this equals the former
        # ``range(max(self.y) + 1)`` but it is computed once (not per drawn
        # sample) and does not raise KeyError when label ids have gaps.
        classes = sorted(label_cluster)
        assert classes, "cannot sample from an empty dataset"

        # Sanity check on the lowest label (class 0 for 0-based labels).
        first_size = len(label_cluster[classes[0]])
        assert first_size > self.num_pos_samples, \
            f"only {first_size} samples in each class, but {self.num_pos_samples} pos samples needed"

        num_pos = self.num_pos_samples
        num_neg = self.batch_size - num_pos

        batch_indices = []
        exhausted = False
        while not exhausted:
            # find a valid positive sample class
            pos_candidates = [c for c in classes if len(label_cluster[c]) >= num_pos]
            if not pos_candidates:
                break
            class_count = random.choice(pos_candidates)

            # fill in positive samples: move the last num_pos indices of the class.
            # An explicit cut point is used because ``[-0:]`` would select the
            # whole list when num_pos is 0.
            pool = label_cluster[class_count]
            cut = len(pool) - num_pos
            batch = pool[cut:]
            del pool[cut:]
            batch_indices.append(batch)

            # fill in negative samples, one at a time, from any other non-empty class
            for _ in range(num_neg):
                neg_candidates = [c for c in classes if c != class_count and label_cluster[c]]
                if not neg_candidates:
                    # Out of negatives: keep the partial batch and stop afterwards.
                    exhausted = True
                    break
                batch.append(label_cluster[random.choice(neg_candidates)].pop())

            random.shuffle(batch)

        random.shuffle(batch_indices)
        # Linear-time flatten (``sum(lists, [])`` is quadratic).
        return iter([idx for batch in batch_indices for idx in batch])

    def __len__(self):
        return self.length


class TrainSamplerMultiClass(Sampler):
    """Sampler that emits indices in same-class groups of ``batch_size // num_classes``.

    For every class the first ``class_samples_needed`` shuffled indices are cut
    into groups of ``batch_size // num_classes``; all groups are shuffled and
    flattened, so a batch of ``batch_size`` consecutive indices is made of
    ``num_classes`` such groups.

    Args:
        dataset: Object exposing ``x`` and ``y``; ``y`` must support ``tolist()``.
        batch_size: Batch size; must be a multiple of ``num_classes``.
        num_classes: Number of same-class groups per batch.
        samples_per_author: Requested number of samples per class; rounded down
            to a multiple of the group size.
    """

    def __init__(self, dataset, batch_size, num_classes, samples_per_author):
        super().__init__(None)
        self.dataset = dataset
        self.batch_size = batch_size
        self.x = dataset.x
        self.y = dataset.y
        self.num_classes = num_classes
        self.samples_per_author = samples_per_author
        assert batch_size // num_classes * num_classes == batch_size, \
            f'batch size {batch_size} is not a multiple of num of classes {num_classes}'
        print(f'train sampler with batch size = {batch_size} and {num_classes} classes in a batch')
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        label_cluster = {}
        for i, label in enumerate(self.y.tolist()):
            label_cluster.setdefault(label, []).append(i)

        # Having exactly the required number of classes / samples is sufficient,
        # so the checks are non-strict; the message reports the class count
        # rather than dumping the whole index dictionary.
        assert len(label_cluster) >= self.num_classes, \
            f'number of available classes {len(label_cluster)} < required classes {self.num_classes}'

        num_samples_per_class_batch = self.batch_size // self.num_classes
        min_class_samples = min(len(x) for x in label_cluster.values())
        assert min_class_samples >= self.samples_per_author, \
            f"expected {self.samples_per_author} per author, but got {min_class_samples} in the dataset"
        class_samples_needed = self.samples_per_author // num_samples_per_class_batch * num_samples_per_class_batch
        assert class_samples_needed > 0, \
            f"samples_per_author {self.samples_per_author} is smaller than the per-class group size " \
            f"{num_samples_per_class_batch}"

        dataset_matrix = []
        for value in label_cluster.values():
            random.shuffle(value)
            # (group_size, n_groups): column j holds the members of group j.
            dataset_matrix.append(torch.tensor(value[:class_samples_needed]).view(num_samples_per_class_batch, -1))

        # Concatenate the columns of all classes, then turn each column into one group.
        tuples = torch.cat(dataset_matrix, dim=1).transpose(1, 0).split(1, dim=0)
        tuples = [x.flatten().tolist() for x in tuples]
        random.shuffle(tuples)
        flat_indices = [idx for group in tuples for idx in group]

        print(f'from dataset sampler: batch size {self.batch_size}, num of classes in a batch {self.num_classes}, '
              f'num of samples per author in total {self.samples_per_author} (specified) / {class_samples_needed} (true).'
              f'dataset size {len(flat_indices)}')

        return iter(flat_indices)

    def __len__(self):
        return self.length


class TrainSamplerMultiClassUnit(Sampler):
    """Sampler that emits indices in same-class units of ``sample_unit_size``.

    Each class contributes ``len(class) // sample_unit_size`` units; leftover
    samples of a class are dropped for that pass. Units are shuffled and
    flattened into one index stream.

    Args:
        dataset: Object exposing ``x`` and ``y``; ``y`` must support ``tolist()``.
        sample_unit_size: Number of consecutive same-class indices per unit.
    """

    def __init__(self, dataset, sample_unit_size):
        super().__init__(None)
        self.x = dataset.x
        self.y = dataset.y
        self.sample_unit_size = sample_unit_size
        print(f'train sampler with sample unit size {sample_unit_size}')
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        label_cluster = {}
        for i, label in enumerate(self.y.tolist()):
            label_cluster.setdefault(label, []).append(i)

        unit = self.sample_unit_size
        dataset_matrix = []
        for value in label_cluster.values():
            random.shuffle(value)
            num_valid_samples = len(value) // unit * unit
            if num_valid_samples == 0:
                # A class too small to form a single unit would make
                # ``view(unit, -1)`` fail on an empty tensor; skip it instead.
                continue
            # (unit, n_units): column j holds the members of unit j.
            dataset_matrix.append(torch.tensor(value[:num_valid_samples]).view(unit, -1))

        if dataset_matrix:
            tuples = torch.cat(dataset_matrix, dim=1).transpose(1, 0).split(1, dim=0)
            tuples = [x.flatten().tolist() for x in tuples]
            random.shuffle(tuples)
            flat_indices = [idx for group in tuples for idx in group]
        else:
            flat_indices = []  # no class could form a unit

        print(f'from dataset sampler: original dataset size {len(self.y)}, resampled dataset size {len(flat_indices)}. '
              f'sample unit size {self.sample_unit_size}')

        return iter(flat_indices)

    def __len__(self):
        return self.length

# util

In [None]:
class AverageMeter(object):
    """
    Running-average tracker.

    Attributes (all reset to 0 by ``reset``):
        val:   the value passed to the latest ``update`` call
        sum:   weighted total of all values seen so far
        count: total weight seen so far
        avg:   ``sum / count``
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum, self.count = self.sum + val * n, self.count + n
        self.avg = self.sum / self.count

# model

In [None]:
# def train_step(self, train_data):
#         results = {}

#         self.content_encoder.encoder.to(self.device)
#         self.style_encoder.model.to(self.device)

#         tokenized_content_data = self.content_encoder.prepare_data(train_data)
#         tokenized_style_data = self.style_encoder.prepare_data(train_data)

#         content_trainloader = DataLoader(tokenized_content_data, batch_size=self.batch_size, shuffle=True)
#         style_trainloader = DataLoader(tokenized_style_data, batch_size=self.batch_size, shuffle=True)

#         # Training Loop
#         for epoch in range(self.num_epochs):
#             self.content_encoder.encoder.train()
#             self.style_encoder.model.train()
#             total_loss = 0

#             for i in range(self.num_iters):
#                 for batch1, batch2 in zip(content_trainloader, style_trainloader):
#                     content_embedding = self.content_encoder.get_content_embeddings(batch1).to(self.device)
#                     style_embedding = self.style_encoder.get_style(batch2).to(self.device)

#                     mi_estimator = CLUB(content_embedding, style_embedding, hidden_size=self.hidden_size).to(self.device)
#                     mi_estimator.eval()
#                     sampler_loss = mi_estimator(content_embedding, style_embedding)

#                     self.content_opt.zero_grad()
#                     self.style_opt.zero_grad()

#                     sampler_loss.backward()

#                     self.content_opt.step()
#                     self.style_opt.step()

#                 for j in range(5):
#                     mi_estimator.train()
#                     for batch1, batch2 in zip(content_trainloader, style_trainloader):
#                         content_embedding = self.content_encoder.get_content_embeddings(batch1)
#                         style_embedding = self.style_encoder.get_style(batch2)

#                         mi_loss = mi_estimator.learning_loss(content_embedding, style_embedding)
#                         mi_loss = torch.mean(mi_loss)

#                         mi_optimizer.zero_grad()
#                         mi_loss.backward()
#                         mi_optimizer.step()

#                 mi_loss = mi_estimator.learning_loss(content_embedding, style_embedding)
#                 mi_loss = torch.mean(mi_loss)

#                 if i % 50 == 0:
#                     print(f"Step {i+1} - Loss: {mi_loss:.4f}")

#                 loss = mi_loss.cpu().detach().numpy().tolist()
#                 results["step {}".format(i)] = loss

#             mi_loss = mi_estimator.learning_loss(content_embedding, style_embedding)
#             mi_loss = torch.mean(mi_loss)
#             total_loss += mi_loss

#         average_loss = total_loss / len(content_trainloader)
#         results["EPOCH {}:".format(i)] = average_loss
#         print(f"Epoch {epoch+1}/{self.num_epochs} - Average Loss: {average_loss:.4f}")

#         with open('/content/drive/MyDrive/msc_project/model/contrastive/club/mi_loss_step.json', 'w') as json_file:
#             json.dump(results, json_file, indent=4)
#         return results

#     def train_batch(self, train_data, test_data):
#         results = {}

#         self.content_encoder.encoder.to(self.device)
#         self.style_encoder.model.to(self.device)

#         tokenized_content_data = self.content_encoder.prepare_data(train_data)
#         tokenized_style_data = self.style_encoder.prepare_data(train_data)

#         content_trainloader = DataLoader(tokenized_content_data, batch_size=self.batch_size, shuffle=True)
#         style_trainloader = DataLoader(tokenized_style_data, batch_size=self.batch_size, shuffle=True)

#         sample_dim = 768
#         # mi_estimator = CLUB(content_embedding, style_embedding, hidden_size=self.hidden_size).to(self.device)
#         mi_estimator = CLUB(sample_dim, sample_dim, self.hidden_size).cuda()
#         mi_optimizer = torch.optim.Adam(mi_estimator.parameters(), lr=self.lr)

#         epoch_losses = []
#         # Training Loop
#         for epoch in range(self.num_epochs):
#             print('epoch', epoch)
#             self.content_encoder.encoder.train()
#             self.style_encoder.model.train()

#             iter_losses = []
#             tqdm_iter = tqdm(range(self.num_iters), desc="Iterations", leave=False, total=self.num_iters)

#             for i in tqdm_iter:
#                 # print('iterations', i)
#                 iter_loss = 0
#                 count = 0
#                 sampler_loss_all = 0
#                 for batch1, batch2 in zip(content_trainloader, style_trainloader):
#                     if count % 50 == 0:
#                         print('count', count)
#                         if count != 0:
#                           print(f'sampler loss {sampler_loss_all / count}')
#                     count += 1
#                     # 32 * 512 * 768
#                     content_embedding = self.content_encoder.get_content_embeddings(batch1)
#                     style_embedding = self.style_encoder.get_style(batch2)

#                     mi_estimator.eval()
#                     sampler_loss = mi_estimator(content_embedding, style_embedding)
#                     # print(sampler_loss)

#                     self.content_opt.zero_grad()
#                     self.style_opt.zero_grad()

#                     sampler_loss.backward()

#                     self.content_opt.step()
#                     self.style_opt.step()

#                     sampler_loss_all += sampler_loss.item()


#                 for j in range(5):
#                     print(j)
#                     mi_loss_temp = 0
#                     mi_estimator.train()
#                     for batch1, batch2 in zip(content_trainloader, style_trainloader):
#                         content_embedding = self.content_encoder.get_content_embeddings(batch1)
#                         style_embedding = self.style_encoder.get_style(batch2)

#                         # mi_optimizer = torch.optim.Adam(mi_estimator.parameters(), lr=self.lr)

#                         mi_loss = mi_estimator.learning_loss(content_embedding, style_embedding)
#                         mi_loss = torch.mean(mi_loss)

#                         mi_optimizer.zero_grad()
#                         mi_loss.backward()
#                         mi_optimizer.step()

#                         iter_loss += mi_loss.item()
#                         mi_loss_temp += mi_loss.item()
#                     print('mi loss', mi_loss_temp / len(content_trainloader))

#                 iter_loss /= 10
#                 iter_losses.append(iter_loss)
#                 tqdm_iter.set_postfix(loss=iter_loss)
#                 if i % 10 == 0:
#                     # tqdm.write(f"Step {i+1} - Loss: {iter_loss:.4f}")
#                     # loss = iter_loss.cpu().detach().numpy().tolist()
#                     results["step {}".format(i)] = iter_loss

#                 # print(mi_estimator(content_embedding, style_embedding))

#             style_checkpoint = "/content/drive/MyDrive/msc_project/model/contrastive/club/style_encoder_1.pt"
#             torch.save(self.style_encoder.model.state_dict(), style_checkpoint)

#             self.test_step(style_checkpoint, train_data, test_data, '')


#             epoch_loss = sum(iter_losses) / len(iter_losses)
#             epoch_losses.append(epoch_loss)
#             results["EPOCH {}:".format(i)] = epoch_loss
#             print(f"Epoch {epoch+1}/{self.num_epochs} - Average Loss: {epoch_loss:.4f}")

#         with open('/content/drive/MyDrive/msc_project/model/contrastive/club/mi_loss_batch.json', 'w') as json_file:
#             json.dump(results, json_file, indent=4)
#         return results

In [None]:
class ContentEncoder():
    """Wrapper around a pretrained ``BertGenerationEncoder`` and its tokenizer.

    Args:
        checkpoint: Name or path handed to both ``from_pretrained`` calls.
        train_data: Stored on the instance; not used by the methods below.
        test_data: Stored on the instance; not used by the methods below.
    """

    def __init__(self,
                 checkpoint,
                 train_data,
                 test_data,):
        self.train_data = train_data
        self.test_data = test_data
        self.checkpoint = checkpoint

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder = BertGenerationEncoder.from_pretrained(self.checkpoint)
        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
        self.lr = 2e-5
        self.max_length = 256

    def prepare_data(self, data):
        """Tokenize every row of ``data`` and expose the result as torch tensors.

        ``data`` must provide ``map``, ``remove_columns`` and ``set_format``
        and contain the columns ``sentence1`` and ``label`` (both are dropped
        after tokenization).
        """
        processed_data = data.map(self.prepare_inputs)
        processed_data = processed_data.remove_columns(["sentence1", "label"])
        processed_data.set_format(
            type="torch",
            columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
        )
        return processed_data

    def prepare_inputs(self, examples):
        """Add encoder/decoder inputs and labels for one example.

        The source and the target are the same sentence, so it is tokenized
        once (previously two identical ``encode_plus`` calls were made) and the
        token lists are copied so that every field owns its own list.
        """
        encoded = self.tokenizer.encode_plus(
            examples["sentence1"], truncation=True, padding="max_length", max_length=self.max_length)

        examples["input_ids"] = encoded.input_ids
        examples["attention_mask"] = encoded.attention_mask
        examples["decoder_input_ids"] = list(encoded.input_ids)
        examples["decoder_attention_mask"] = list(encoded.attention_mask)
        examples["labels"] = list(encoded.input_ids)
        return examples

    def get_content_embeddings(self, data):
        """Run the encoder and return its ``last_hidden_state``.

        Args:
            data: Indexable batch where ``data[0]`` is passed as ``input_ids``
                and ``data[2]`` as ``attention_mask``. The tensors are not
                moved here, so the caller must already have placed them on
                ``self.device``.

        Returns:
            ``outputs.last_hidden_state`` of the encoder. Gradients are not
            disabled here; wrap the call in ``torch.no_grad()`` if needed.
        """
        encoder = self.encoder.to(self.device)
        outputs = encoder(input_ids=data[0], attention_mask=data[2])
        return outputs.last_hidden_state


class LogisticRegression(nn.Module):
    """Classification head: dropout -> linear -> LeakyReLU(0.2) -> dropout -> linear.

    Args:
        in_dim: Size of the input features.
        hid_dim: Width of the hidden linear layer.
        out_dim: Number of outputs (logits).
        dropout: Dropout probability applied before each linear layer.
    """

    def __init__(self, in_dim, hid_dim, out_dim, dropout=0):
        super().__init__()
        print(f'Logistic Regression classifier of dim ({in_dim} {hid_dim} {out_dim})')

        # The attribute name and the layer order define the state-dict keys
        # ("nn.1.*", "nn.4.*"), so they must not change.
        layers = [
            nn.Dropout(p=dropout),
            nn.Linear(in_dim, hid_dim, bias=True),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hid_dim, out_dim, bias=True),
        ]
        self.nn = nn.Sequential(*layers)

    def forward(self, x, return_feat=False):
        """Return the logits, or ``(logits, x)`` when ``return_feat`` is true."""
        logits = self.nn(x)
        return (logits, x) if return_feat else logits


class BertClassifier(nn.Module):
    """Feature extractor followed by a classifier over the flattened token states.

    Args:
        raw_bert: Module called with ``input_ids`` and ``attention_mask`` whose
            output exposes ``last_hidden_state``.
        classifier: Module applied to ``last_hidden_state.flatten(1)``.
    """

    FEAT_LEN = 768

    def __init__(self, raw_bert, classifier):
        super().__init__()
        self.bert = raw_bert
        self.fc = classifier

    def forward(self, x, return_feat=False):
        """Classify a tokenized batch.

        Args:
            x: Indexable batch; ``x[0]`` is used as ``input_ids`` and ``x[2]``
                as ``attention_mask`` (``x[1]`` is not used).
            return_feat: When true, also return the hidden states and the raw
                extractor output.

        Returns:
            ``out`` or ``(out, last_hidden_state, extractor_output)``.
        """
        input_ids, attention_mask = x[0], x[2]
        feature = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = feature.last_hidden_state
        # Every token state is fed to the classifier, not the pooled output.
        out = self.fc(hidden.flatten(1))
        if not return_feat:
            return out
        return out, hidden, feature


def load_model_dic(model, ckpt_path, verbose=True, strict=True):
    """
    Load weights to model and take care of weight parallelism

    The checkpoint is expected to be a flat state dict. If it was saved from a
    ``DataParallel`` wrapper (every key starts with ``module.``) while ``model``
    itself is not wrapped, the prefix is stripped before loading.

    Args:
        model: Module to load the weights into (modified in place).
        ckpt_path: Path of the checkpoint file.
        verbose: Print a confirmation line after loading.
        strict: Forwarded to ``model.load_state_dict``.

    Returns:
        The same ``model`` object.

    Raises:
        AssertionError: if ``ckpt_path`` does not exist.
        RuntimeError: from ``load_state_dict`` when the keys or shapes do not
            match (with ``strict=True`` for keys).
    """
    assert os.path.exists(ckpt_path), f"trained model {ckpt_path} does not exist"

    # Read the file once and onto the CPU: load_state_dict copies the values
    # into the model's existing parameters, so checkpoints written on a GPU
    # machine also load on a CPU-only one.
    state_dict = torch.load(ckpt_path, map_location="cpu")

    prefix = 'module.'
    model_keys = set(model.state_dict().keys())
    saved_from_wrapper = (
        len(state_dict) > 0
        and all(key.startswith(prefix) for key in state_dict)
        and not any(key in model_keys for key in state_dict)
    )
    if saved_from_wrapper:
        # Decide up front instead of relying on a failed first attempt: with
        # strict=False a prefixed checkpoint used to "load" without error while
        # matching no parameter at all. Only the leading prefix is removed.
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}

    model.load_state_dict(state_dict, strict=strict)
    if verbose:
        print(f'Model loaded: {ckpt_path}')

    return model


def save_model(ckpt_dir, cp_name, model):
    """
    Save the state dict of ``model`` to ``ckpt_dir/cp_name``.

    ``ckpt_dir`` is created if it does not exist. A ``torch.nn.DataParallel``
    wrapper is unwrapped first, so the saved keys carry no ``module.`` prefix.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    saving_model_path = os.path.join(ckpt_dir, cp_name)
    # convert to non-parallel form
    unwrapped = model.module if isinstance(model, torch.nn.DataParallel) else model
    torch.save(unwrapped.state_dict(), saving_model_path)
    print(f'Model saved: {saving_model_path}')


class StyleEncoder():
    """Wrapper around a fine-tuned ``BertClassifier`` used as the style model.

    Args:
        checkpoint: Path of the state dict loaded into the classifier through
            ``load_model_dic`` (strict loading).
        train_data: Stored on the instance; not used by the methods below.
        test_data: Stored on the instance; not used by the methods below.
    """

    def __init__(
            self,
            checkpoint,
            train_data,
            test_data):

        self.checkpoint = checkpoint
        self.train_data = train_data
        self.test_data = test_data

        # Architecture of the checkpointed classifier: 256 tokens of 768-dim
        # features -> 512 hidden units -> 150 outputs, dropout 0.35.
        num_tokens, hidden_dim, out_dim, dropout = 256, 512, 150, 0.35
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased', padding=True, truncation=True)
        extractor = BertModel.from_pretrained('bert-base-cased')
        head = LogisticRegression(768 * num_tokens, hidden_dim, out_dim, dropout=dropout)
        self.model = load_model_dic(BertClassifier(extractor, head), checkpoint, verbose=True, strict=True)
        # NOTE: this is the generator returned by parameters(); it can be
        # consumed only once.
        self.parameters = self.model.parameters()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.max_length = 256

    def prepare_inputs(self, examples):
        """Tokenize ``examples['sentence1']``, padded/truncated to ``self.max_length``."""
        return self.tokenizer.encode_plus(
            examples['sentence1'],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

    def prepare_data(self, data):
        """Tokenize every row, drop ``sentence1``/``label`` and switch to torch format."""
        tokenized = data.map(self.prepare_inputs)
        tokenized = tokenized.remove_columns(["sentence1", "label"])
        tokenized.set_format(type="torch",)
        return tokenized

    def get_style(self, encoded_dict):
        """Run the style model on a tokenized batch.

        Args:
            encoded_dict: Batch passed unchanged to ``self.model`` with
                ``return_feat=True``.

        Returns:
            ``(pred, feats)``: the first two of the three values returned by
            the model; the third one is discarded.
        """
        pred, feats, _ = self.model(encoded_dict, return_feat=True)
        return pred, feats


class DualEncoder(nn.Module):
    def __init__(self,
                 content_encoder,
                 style_encoder,
                 num_epochs,
                 batch_size,
                 num_iters,
                 log_step,
                 lr=2e-5,
                 hidden_size=500):
        """Store the two encoders and build one Adam optimizer for each.

        Args:
            content_encoder: Object exposing ``encoder`` (its ``parameters()``
                feed ``self.content_opt``).
            style_encoder: Object exposing ``model`` (its ``parameters()`` feed
                ``self.style_opt``).
            num_epochs: Stored as ``self.num_epochs``.
            batch_size: Stored as ``self.batch_size``.
            num_iters: Stored as ``self.num_iters``.
            log_step: Stored as ``self.log_step``.
            lr: Learning rate used for both optimizers.
            hidden_size: Stored as ``self.hidden_size``.

        NOTE(review): if the encoders are the plain ``ContentEncoder`` /
        ``StyleEncoder`` wrappers from this notebook they are not ``nn.Module``
        instances, so assigning them here does not register their weights as
        submodules of this module -- confirm this is intended.
        """

        super(DualEncoder, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.content_encoder = content_encoder
        self.style_encoder = style_encoder
        self.lr = lr
        self.num_iters = num_iters
        self.hidden_size = hidden_size
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.log_step = log_step
        # Separate optimizers so the two encoders can be stepped independently.
        self.content_opt = torch.optim.Adam(self.content_encoder.encoder.parameters(), self.lr)
        self.style_opt = torch.optim.Adam(self.style_encoder.model.parameters(), lr = self.lr)

    def test_step(self, style_checkpoint, train_loader, test_data, test_loader, output_file):
        """Evaluate a saved style checkpoint with ``RegionComparison``.

        Args:
            style_checkpoint: Passed to ``RegionComparison`` as ``checkpoint``.
            train_loader: Passed through to ``RegionComparison``.
            test_data: Passed through to ``RegionComparison``.
            test_loader: Passed through to ``RegionComparison``.
            output_file: Passed through to ``RegionComparison``.

        Returns:
            tuple: ``(metrics, result_topk)`` where ``metrics`` is the result
            of ``compute_metrics()`` and ``result_topk`` is currently always an
            empty dict, because the top-k code below is commented out.

        NOTE(review): ``RegionComparison`` is not defined in the part of the
        notebook shown here -- presumably it comes from another cell.
        """
        ranking_eval = RegionComparison(checkpoint = style_checkpoint,
                                        train_loader=train_loader,
                                        test_data=test_data,
                                        test_loader=test_loader,
                                        output_file=output_file)

        # result_list = ranking_eval.compare_to_regions(train_data, test_data)
        # with open(output_file, 'w') as f:
        #     for dictionary in result_list:
        #         json_string = json.dumps(dictionary)
        #         f.write(json_string + '\n')

        # Only referenced by the commented-out dump code below; unused for now.
        metrics_file = "/content/drive/MyDrive/msc_project/model/contrastive/club/metrics.json"
        accuracy_k = "/content/drive/MyDrive/msc_project/model/contrastive/club/accuracy_k.json"

        metrics = ranking_eval.compute_metrics()
        # with open(metrics_file, 'w') as f:
        #     json.dump(metrics, f, indent=4)
        # metrics = {}

        # Stays empty while the top-k evaluation below is disabled.
        result_topk = {}

        # for i in range(1,2):
        #     metrics_topk = ranking_eval.compute_top_k(i)
        #     result_topk[f'Accuracy top-{i}'] = metrics_topk

        # metrics_top10 = ranking_eval.compute_top_k(10)
        # result_topk[f'Accuracy top-10'] = metrics_top10

        # metrics_top20 = ranking_eval.compute_top_k(20)
        # result_topk[f'Accuracy top-20'] = metrics_top20

        # with open(accuracy_k, 'w') as f:
        #     json.dump(result_topk, f, indent=4)
        return metrics, result_topk


    def train(self, train_dict, test_dic, val_dic=None):
        self.content_encoder.encoder.to(self.device)
        self.style_encoder.model.to(self.device)

        train_x, train_y = train_dict['content'].tolist(), train_dict['Target'].tolist()
        test_x, test_y = test_dic['content'].tolist(), test_dic['Target'].tolist()

        if val_dic is not None:
            val_x, val_y = val_dic['content'].tolist(), val_dic['Target'].tolist()

        from transformers import BertTokenizer, BertModel
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        extractor = BertModel.from_pretrained('bert-base-cased')

        num_tokens, hidden_dim = 256, 512

        train_set = BertDataset(train_x, train_y, tokenizer, num_tokens)
        test_set = BertDataset(test_x, test_y, tokenizer, num_tokens)

        if val_dic is not None:
            val_set = BertDataset(val_x, val_y, tokenizer, num_tokens)

        temperature, sample_unit_size = 0.1, 2
        # print(f'coefficient, temperature, sample_unit_size = {coefficient, temperature, sample_unit_size}')

        # load data
        train_sampler = TrainSamplerMultiClassUnit(train_set, sample_unit_size=sample_unit_size)
        train_loader = DataLoader(train_set, batch_size=self.batch_size, sampler=train_sampler, shuffle=False,
                                  num_workers=4, pin_memory=True, drop_last=False)
        # train_loader = DataLoader(train_set, batch_size=base_bs * ngpus, shuffle=True,
        #                           num_workers=4 * ngpus, pin_memory=True, drop_last=False)
        test_loader = DataLoader(test_set, batch_size=self.batch_size, shuffle=False, num_workers=4,
                                pin_memory=True, drop_last=False)

        if val_dic is not None:
            val_loader = DataLoader(val_set, batch_size=self.batch_size, shuffle=False, num_workers=4,
                                    pin_memory=True, drop_last=True)

        sample_dim = 196608
        # sample_dim = 768
        mi_estimator = CLUB(sample_dim, sample_dim, self.hidden_size).cuda()
        mi_optimizer = torch.optim.Adam(mi_estimator.parameters(), lr=1e-3)
        supcon = SupConLoss_contrastiveAA()
        criterion = nn.CrossEntropyLoss()

        # style_checkpoint = f"/content/drive/MyDrive/msc_project/model/contrastive/club_blogs50/style_encoder_supcon_0.pt"
        # self.test_step(style_checkpoint, train_loader, test_dic, test_loader, '')


        for epoch in range(5):  # num_iter_phase1 is the number of iterations for this phase
            train_estimator_loss = AverageMeter()
            pg = tqdm(train_loader, leave=False, total=len(train_loader), disable=False)
            for i, (x1, x2, x3, y) in enumerate(pg):
                x, y = (x1.cuda(), x2.cuda(), x3.cuda()), y.cuda()

                # Encode the samples using style and content encoders
                # These encoders are not updated in this phase
                with torch.no_grad():  # Ensure that encoders are not trained
                    content_embedding = self.content_encoder.get_content_embeddings(x)
                    pred, style_embedding = self.style_encoder.get_style(x)

                # Train the MI estimator
                mi_estimator.train()
                mi_loss = mi_estimator.learning_loss(content_embedding.flatten(1), style_embedding.flatten(1))
                # mi_loss = mi_estimator.learning_loss(content_embedding, style_embedding)

                mi_optimizer.zero_grad()
                mi_loss.backward()
                mi_optimizer.step()

                train_estimator_loss.update(mi_loss.item())

                pg.set_postfix({
                    'train eistimator loss': '{:.6f}'.format(train_estimator_loss.avg),
                    'epoch': '{:03d}'.format(epoch)
                })
            print(f'epoch {epoch}, train eistimator loss {train_estimator_loss.avg}')

        print()
        epoch_losses = []

        for epoch in range(self.num_epochs):

            # numbers = list(range(len(train_loader)))
            # random_numbers = random.sample(numbers, 500)
            # print(random_numbers)

            self.content_encoder.encoder.train()
            self.style_encoder.model.train()
            mi_estimator.eval()

            count = 0
            sampler_loss_all = 0

            train_acc = AverageMeter()
            train_loss = AverageMeter()
            train_sampler_loss = AverageMeter()
            train_supcon_loss = AverageMeter()
            train_cls_loss = AverageMeter()
            pg = tqdm(train_loader, leave=False, total=len(train_loader), disable=False)
            for i, (x1, x2, x3, y) in enumerate(pg):
                # if i not in random_numbers:
                #     continue

                x, y = (x1.cuda(), x2.cuda(), x3.cuda()), y.cuda()
                content_embedding = self.content_encoder.get_content_embeddings(x)
                pred, style_embedding = self.style_encoder.get_style(x)
                # print('pred', pred.argmax(1))
                # print('y', y)

                # sampler_loss = mi_estimator(style_embedding, content_embedding)
                sampler_loss = mi_estimator(content_embedding.flatten(1), style_embedding.flatten(1))
                supcon_loss = supcon(style_embedding.flatten(1), y.long())
                cls_loss = criterion(pred, y.long())

                coe = 5
                # loss = coe * sampler_loss + supcon_loss + cls_loss
                loss = coe * sampler_loss + supcon_loss

                acc = (pred.argmax(1) == y).sum().item() / len(y)

                train_acc.update(acc)
                train_loss.update(loss.item())
                train_sampler_loss.update(sampler_loss.item())
                train_supcon_loss.update(supcon_loss.item())
                train_cls_loss.update(cls_loss.item())

                self.content_opt.zero_grad()
                self.style_opt.zero_grad()

                loss.backward()

                self.content_opt.step()
                self.style_opt.step()


                pg.set_postfix({
                    'train acc': '{:.6f}'.format(train_acc.avg),
                    'train sampler loss': '{:.6f}'.format(train_sampler_loss.avg),
                    'train supcon loss': '{:.6f}'.format(train_supcon_loss.avg),
                    'train cls loss': '{:.6f}'.format(train_cls_loss.avg),
                    'train L': '{:.6f}'.format(train_loss.avg),
                    'epoch': '{:03d}'.format(epoch)
                })

            print(f'epoch {epoch}, sampler loss {train_sampler_loss.avg}, supcon loss {train_supcon_loss.avg}, cls loss {train_cls_loss.avg}, loss {train_loss.avg}, style acc {train_acc.avg}')

            # style_checkpoint = "/content/drive/MyDrive/msc_project/model/contrastive/club/style_encoder_1.pt"
            # torch.save(self.style_encoder.model.state_dict(), style_checkpoint)
            # self.test_step(style_checkpoint, train_dict, test_dic, '')

            for j in range(5):
                # numbers = list(range(len(train_loader)))
                # random_numbers = random.sample(numbers, 500)
                # print(random_numbers)

                train_estimator_loss = AverageMeter()
                mi_estimator.train()
                pg = tqdm(train_loader, leave=False, total=len(train_loader), disable=False)
                for i, (x1, x2, x3, y) in enumerate(pg):
                    # if i not in random_numbers:
                    #     continue

                    x, y = (x1.cuda(), x2.cuda(), x3.cuda()), y.cuda()
                    content_embedding = self.content_encoder.get_content_embeddings(x)
                    pred, style_embedding = self.style_encoder.get_style(x)

                    mi_loss = mi_estimator.learning_loss(content_embedding.flatten(1), style_embedding.flatten(1))
                    # mi_loss = mi_estimator.learning_loss(content_embedding, style_embedding)
                    mi_optimizer.zero_grad()
                    mi_loss.backward()
                    mi_optimizer.step()

                    train_estimator_loss.update(mi_loss.item())

                    pg.set_postfix({
                        'train eistimator loss': '{:.6f}'.format(train_estimator_loss.avg),
                        'iteration': '{:03d}'.format(j)
                    })

                print(f'iteration {j}, train eistimator loss {train_estimator_loss.avg}')

            style_checkpoint = f"/content/drive/MyDrive/msc_project/model/contrastive/club_100_150/style_encoder_supcon1_{epoch}.pt"
            torch.save(self.style_encoder.model.state_dict(), style_checkpoint)
            content_checkpoint = f"/content/drive/MyDrive/msc_project/model/contrastive/club_100_150/content_encoder_supcon1_{epoch}.pt"
            torch.save(self.content_encoder.encoder.state_dict(), content_checkpoint)
            self.test_step(style_checkpoint, train_loader, test_dic, test_loader, '')


# loss

In [None]:
class CLUB(nn.Module):
    """CLUB (Contrastive Log-ratio Upper Bound) estimator of I(X; Y).

    Two small MLPs parameterise a diagonal Gaussian approximation q(Y | X):
    ``p_mu`` outputs its mean and ``p_logvar`` its log-variance (squashed to
    [-1, 1] by the final ``Tanh``).

    Args:
        x_dim: feature dimension of the X samples fed to both MLPs.
        y_dim: feature dimension of the Y samples (output size of both MLPs).
        hidden_size: the hidden layer of each MLP has ``hidden_size // 2`` units.

    Methods:
        forward(x, y): the CLUB estimate for a batch of paired samples.
        loglikeli(x, y): unnormalised log-likelihood of y under q(Y | X).
        learning_loss(x, y): negative of ``loglikeli``; minimised to fit q.
    """

    def __init__(self, x_dim, y_dim, hidden_size):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.hidden_size = hidden_size
        half_hidden = self.hidden_size // 2

        # Build the mean network first, then the log-variance network, so the
        # layers are initialised in the same order as before.
        self.p_mu = nn.Sequential(
            nn.Linear(x_dim, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, y_dim),
        )
        self.p_logvar = nn.Sequential(
            nn.Linear(x_dim, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, y_dim),
            nn.Tanh(),
        )

    def get_mu_logvar(self, data):
        """Return ``(mu, logvar)`` of q(Y | X = data)."""
        return self.p_mu(data), self.p_logvar(data)

    def forward(self, x_samples, y_samples):
        """Return the scalar CLUB estimate for paired ``x_samples`` / ``y_samples``."""
        mu, logvar = self.get_mu_logvar(x_samples)
        two_var = 2. * torch.exp(logvar)

        # Log-density term (up to constants) of each matched (x_i, y_i) pair.
        positive = -(mu - y_samples) ** 2 / two_var

        # Same term averaged over every y_j paired with each x_i:
        # [n, 1, d] against [1, n, d] broadcasts to [n, n, d].
        pairwise_sq = (y_samples.unsqueeze(0) - mu.unsqueeze(1)) ** 2
        negative = -(pairwise_sq.mean(dim=1) / two_var)

        gap = positive.sum(dim=-1) - negative.sum(dim=-1)
        return gap.mean()

    def loglikeli(self, x_samples, y_samples):
        """Return the batch-mean unnormalised log-likelihood of y under q(Y | X)."""
        mu, logvar = self.get_mu_logvar(x_samples)
        per_dim = -(mu - y_samples) ** 2 / logvar.exp() - logvar
        return per_dim.sum(dim=1).mean(dim=0)

    def learning_loss(self, x_samples, y_samples):
        """Return the loss used to train the estimator: ``-loglikeli``."""
        return -self.loglikeli(x_samples, y_samples)

class SupConLoss_contrastiveAA(nn.Module):
    def __init__(self, temperature=0.1, margin=0.2):
        """
        Supervised contrastive loss over cosine similarities, in the spirit of
        Supervised Contrastive Learning: https://arxiv.org/abs/2004.11362

        For each (anchor, positive) pair the denominator contains that pair's
        own similarity term plus the terms of the anchor's negatives (samples
        with a different label); other positives are not part of it.

        :param temperature: float, divisor applied to the cosine similarities
        :param margin: float, accepted for interface compatibility; it is not
            used anywhere in this loss
        """
        super(SupConLoss_contrastiveAA, self).__init__()
        self.temperature = temperature
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, projections, targets):
        """
        :param projections: torch.Tensor, shape [batch_size, projection_dim]
        :param targets: torch.Tensor, shape [batch_size]
        :return: torch.Tensor, scalar
        """
        # Follow the tensor's actual device. A bare torch.device("cuda") means
        # "the current CUDA device", which is not necessarily where
        # `projections` lives when more than one GPU is in use.
        device = projections.device
        batch_size = projections.shape[0]

        # Pairwise cosine similarity, [batch, batch], scaled by temperature.
        dot_product_tempered = self.cos(projections.unsqueeze(1), projections.unsqueeze(0)) / self.temperature

        # Subtract the row-wise max before exponentiating for numerical
        # stability; the 1e-5 keeps every entry strictly positive.
        exp_dot_tempered = (
            torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0]) + 1e-5
        )

        mask_similar_class = (targets.unsqueeze(1) == targets.unsqueeze(0)).to(device)
        # Everything except the diagonal, i.e. drop each sample's self-similarity.
        mask_anchor_out = ~torch.eye(batch_size, dtype=torch.bool, device=device)

        mask_combined_pos = mask_similar_class & mask_anchor_out
        mask_combined_neg = (~mask_similar_class) & mask_anchor_out

        # Anchors without any positive would divide by zero; clamping to 1
        # makes them contribute exactly 0 (their masked sum is 0) while still
        # being counted in the batch mean.
        cardinality_pos = mask_combined_pos.sum(dim=1).clamp(min=1)

        exp_sum_neg = torch.sum(exp_dot_tempered * mask_combined_neg, dim=1)
        prob = exp_dot_tempered / (exp_dot_tempered + exp_sum_neg.view(-1, 1) + 1e-5)

        # Keep only the (anchor, positive) entries of -log p.
        log_prob = -torch.log(prob) * mask_combined_pos

        total_loss = torch.mean(torch.sum(log_prob, dim=1) / cardinality_pos)

        return total_loss


# class CLUB(nn.Module):  # CLUB: Mutual Information Contrastive Learning Upper Bound
#     '''
#         This class provides the CLUB estimation to I(X,Y)
#         Method:
#             forward() :      provides the estimation with input samples
#             loglikeli() :   provides the log-likelihood of the approximation q(Y|X) with input samples
#         Arguments:
#             x_dim, y_dim :         the dimensions of samples from X, Y respectively
#             hidden_size :          the dimension of the hidden layer of the approximation network q(Y|X)
#             x_samples, y_samples : samples from X and Y, having shape [sample_size, x_dim/y_dim]
#     '''
#     def __init__(self, x_dim, y_dim, hidden_size):
#         super(CLUB, self).__init__()
#         # p_mu outputs mean of q(Y|X)
#         #print("create CLUB with dim {}, {}, hiddensize {}".format(x_dim, y_dim, hidden_size))
#         self.p_mu = nn.Sequential(nn.Linear(x_dim, hidden_size//2),
#                                        nn.ReLU(),
#                                        nn.Linear(hidden_size//2, y_dim))
#         # p_logvar outputs log of variance of q(Y|X)
#         self.p_logvar = nn.Sequential(nn.Linear(x_dim, hidden_size//2),
#                                        nn.ReLU(),
#                                        nn.Linear(hidden_size//2, y_dim),
#                                        nn.Tanh())

#     def get_mu_logvar(self, x_samples):
#         mu = self.p_mu(x_samples)
#         logvar = self.p_logvar(x_samples)
#         return mu, logvar

#     def forward(self, x_samples, y_samples):
#         mu, logvar = self.get_mu_logvar(x_samples)
#         # print(mu)
#         # print(logvar)

#         # log of conditional probability of positive sample pairs
#         positive = - (mu - y_samples)**2 /2./logvar.exp()
#         # print('positive', positive)
#         # print(positive.sum(dim = -1))

#         prediction_1 = mu.unsqueeze(1)          # shape [nsample,1,dim]
#         # print(prediction_1.shape)
#         y_samples_1 = y_samples.unsqueeze(0)    # shape [1,nsample,dim]

#         # log of conditional probability of negative sample pairs
#         negative = - ((y_samples_1 - prediction_1)**2).mean(dim=1)/2./logvar.exp()
#         # print('negative', negative)
#         # print(negative.sum(dim = -1))

#         return (positive.sum(dim = -1) - negative.sum(dim = -1)).mean()

#     def loglikeli(self, x_samples, y_samples): # unnormalized loglikelihood
#         mu, logvar = self.get_mu_logvar(x_samples)
#         return (-(mu - y_samples)**2 /logvar.exp()-logvar).sum(dim=1).mean(dim=0)

#     def learning_loss(self, x_samples, y_samples):
#         return - self.loglikeli(x_samples, y_samples)


# evaluation

In [None]:
from sklearn.cluster import KMeans

class RegionComparison():
    """Evaluate a trained style encoder by classifier accuracy and nearest-centroid ranking.

    The class loads a ``BertClassifier`` checkpoint, embeds every batch of a
    loader with it, averages the training embeddings per label ("region"),
    and ranks the labels for each test sample by cosine similarity to those
    per-label means.

    Args:
        checkpoint: checkpoint location handed to ``load_model_dic``.
        train_loader: iterable of ``(x1, x2, x3, y)`` tensor batches used to
            build the per-label mean embeddings.
        test_data: table indexed by ``'Target'`` and ``'content'`` whose
            columns expose ``.tolist()`` (presumably a pandas DataFrame --
            confirm against the caller).
        test_loader: iterable of ``(x1, x2, x3, y)`` tensor batches to score.
            ``compare_to_regions`` pairs its i-th embedded sample with the
            i-th row of ``test_data``, so it should iterate in that order.
        output_file: stored on the instance; not used by any method here.
    """

    def __init__(self,
                 checkpoint,
                 train_loader,
                 test_data,
                 test_loader,
                 output_file = "/content/drive/MyDrive/msc_project/model/contrastive/club/result.json"):
        self.checkpoint = checkpoint
        num_tokens, hidden_dim, out_dim, dropout = 256, 512, 150, 0.35
        self.num_tokens = num_tokens
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased', padding=True, truncation=True)
        # BertClassifier, LogisticRegression and load_model_dic are expected
        # to be defined in earlier notebook cells.
        extractor = BertModel.from_pretrained('bert-base-cased')
        model = BertClassifier(extractor, LogisticRegression(768 * num_tokens, hidden_dim, out_dim, dropout=dropout))
        self.model = load_model_dic(model, checkpoint, verbose=True, strict=True)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.test_data = test_data
        self.output_file = output_file
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.max_length = 256

    def _move_batch(self, x1, x2, x3, y):
        """Move one loader batch to ``self.device`` and return ``((x1, x2, x3), y)``."""
        inputs = (x1.to(self.device), x2.to(self.device), x3.to(self.device))
        return inputs, y.to(self.device)

    def mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings in ``model_output[0]`` over unmasked positions.

        Args:
            model_output: sequence whose first item is ``[batch, max_length, hidden_dim]``.
            attention_mask: ``[batch, max_length]`` mask; 0 entries are excluded.

        Returns:
            Tensor of shape ``[batch, hidden_dim]``.
        """
        token_embeddings = model_output[0]  # batch_size * max_length * hidden_dim
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        input_mask_expanded = input_mask_expanded.to(self.device)
        # The clamp avoids a division by zero for a fully masked row.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def test_AA_acc(self):
        """Print and return the classifier-head accuracy over ``self.test_loader``."""
        all_preds = []
        all_labels = []
        self.model.eval()
        with torch.no_grad():
            for x1, x2, x3, y in self.test_loader:
                x, y = self._move_batch(x1, x2, x3, y)
                pred, feats, model_output = self.model(x, return_feat=True)
                all_preds.append(pred.argmax(1).cpu().detach().numpy())
                all_labels.append(y.cpu().detach().numpy())

        all_preds = np.concatenate(all_preds, axis=0).tolist()
        all_labels = np.concatenate(all_labels, axis=0).tolist()
        correct_predictions = sum(p == l for p, l in zip(all_preds, all_labels))
        accuracy = correct_predictions / len(all_preds)
        print('accuracy', accuracy)
        return accuracy

    # Compute and return average embeddings for each region
    def get_average_embeddings(self, data):
        """Return ``{label: mean flattened feature}`` over all batches of ``data``.

        Args:
            data: iterable of ``(x1, x2, x3, y)`` tensor batches.

        Returns:
            Dict mapping each label value found in ``y`` to the CPU tensor
            holding the mean of ``feats.flatten(1)`` over that label's samples.
        """
        self.model.eval()

        label_sums = {}
        label_counts = {}
        for x1, x2, x3, y in data:
            x, y = self._move_batch(x1, x2, x3, y)
            with torch.no_grad():
                pred, feats, model_output = self.model(x, return_feat=True)
                batch_embeddings = feats.flatten(1).cpu()

            for row, label in enumerate(y.flatten().tolist()):
                embedding = batch_embeddings[row]
                if label not in label_sums:
                    # Clone: indexing returns a view, and the in-place "+="
                    # below would otherwise write into the batch tensor.
                    label_sums[label] = embedding.clone()
                    label_counts[label] = 1
                else:
                    label_sums[label] += embedding
                    label_counts[label] += 1

            # Release the batch before the next forward pass.
            del x, feats
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()

        return {label: total / label_counts[label] for label, total in label_sums.items()}

    def get_clustered_embeddings(self, data, n_clusters=1):
        """Return one K-Means centroid per label, computed from mean-pooled sentence embeddings.

        Args:
            data: iterable of examples with ``'label'`` and ``'sentence1'`` keys.
            n_clusters: number of K-Means clusters fitted per label.

        Returns:
            Dict mapping label to the centroid (numpy array) of that label's
            most populated cluster. With the default ``n_clusters=1`` this is
            the only centroid.
        """
        self.model.eval()

        # Group the sentences by label.
        label_data = {}
        for example in data:
            label_data.setdefault(example['label'], []).append(example['sentence1'])

        # Encode each sentence and mean-pool its token embeddings.
        label_embeddings = {}
        for label, sentences in label_data.items():
            sentence_embeddings = []
            for sentence in sentences:
                encoded_dict = self.tokenizer.encode_plus(
                    sentence,
                    truncation=True,
                    return_tensors='pt',
                    max_length=self.max_length,
                    padding="max_length"
                )
                with torch.no_grad():
                    data_input_ids = encoded_dict['input_ids'].to(self.device)
                    data_attention_mask = encoded_dict['attention_mask'].to(self.device)
                    data_token_type_ids = encoded_dict['token_type_ids'].to(self.device)
                    x = data_input_ids, data_token_type_ids, data_attention_mask
                    pred, feats, model_output = self.model(x, return_feat=True)

                sentence_embedding = self.mean_pooling(model_output, data_attention_mask)
                sentence_embeddings.append(sentence_embedding)

            label_embeddings[label] = torch.stack(sentence_embeddings)

        author_embeddings = {}
        for author, embeddings in label_embeddings.items():
            # [n, 1, hidden] -> [n, hidden] numpy array for scikit-learn.
            embeddings = embeddings.squeeze(1).detach().cpu().numpy()
            kmeans = KMeans(n_clusters=n_clusters)
            kmeans.fit(embeddings)
            # Cluster index 0 is not necessarily the largest one, so select
            # the cluster that actually holds the most sentences.
            largest = np.bincount(kmeans.labels_, minlength=n_clusters).argmax()
            author_embeddings[author] = kmeans.cluster_centers_[largest]
        return author_embeddings

    # Compute and return the ranking for the regions based on the input data
    def compare_to_regions(self, train_loader, test_data, test_loader):
        """Rank every label for every test sample by cosine similarity to the label mean.

        Args:
            train_loader: batches used to build the per-label mean embeddings.
            test_data: table providing ``'Target'`` and ``'content'`` columns;
                row i is paired with the i-th sample yielded by ``test_loader``.
            test_loader: batches whose flattened features are scored.

        Returns:
            List with one dict per row of ``test_data`` holding ``'label'``,
            ``'sentence'`` and ``'similarity_scores'``; the latter maps label
            to similarity and is ordered from most to least similar.
        """
        self.model.eval()

        all_embeddings = []
        for x1, x2, x3, y in test_loader:
            x, y = self._move_batch(x1, x2, x3, y)
            with torch.no_grad():
                pred, feats, model_output = self.model(x, return_feat=True)
                all_embeddings.append(feats.flatten(1).cpu())
        all_embeddings = torch.cat(all_embeddings, dim=0)

        # One (1, num_features) reference vector per label. as_tensor accepts
        # tensors and numpy arrays alike without the copy-construct warning
        # that torch.tensor(tensor) emits.
        region_embeddings = {
            label: torch.as_tensor(embedding).to(self.device).reshape(1, -1)
            for label, embedding in self.get_average_embeddings(train_loader).items()
        }

        cosine = torch.nn.CosineSimilarity(dim=-1)
        similarities = []
        for input_embedding in all_embeddings:
            input_embedding = input_embedding.to(self.device).reshape(1, -1)

            input_similarity = {
                label: cosine(region_embedding, input_embedding).item()
                for label, region_embedding in region_embeddings.items()
            }

            # Most similar label first; dicts keep insertion order.
            sorted_labels = dict(sorted(input_similarity.items(), key=lambda x: x[1], reverse=True))
            similarities.append(sorted_labels)

        test_y = test_data['Target'].tolist()
        test_x = test_data['content'].tolist()
        results_list = []
        for i in range(len(test_data)):
            results_list.append({
                'label': test_y[i],
                'sentence': test_x[i],
                'similarity_scores': similarities[i],
            })

        return results_list

    def get_region_embs(self, train_data, test_data):
        """Return mean-pooled embeddings for every ``test_data['sentence1']`` text.

        NOTE(review): this method calls ``self.model(**encoded_dict)`` whereas
        every other method calls ``self.model(x, return_feat=True)`` with a
        tuple of tensors, and it passes ``train_data`` to
        ``get_average_embeddings``, which iterates ``(x1, x2, x3, y)`` batches.
        It looks like a leftover from an earlier model wrapper -- verify before use.
        The per-label embeddings computed at the end are not returned.
        """
        input_data = test_data['sentence1']

        input_embeddings = []
        for idx, text in enumerate(input_data):
            encoded_dict = self.tokenizer.encode_plus(
                text,
                truncation=True,
                padding="max_length",
                return_tensors='pt',
                max_length=self.max_length,
            )

            # Compute token embeddings
            with torch.no_grad():
                encoded_dict = encoded_dict.to(self.device)
                model_output = self.model(**encoded_dict)

            # Performing mean pooling
            sentence_embeddings = self.mean_pooling(model_output, encoded_dict['attention_mask'])
            input_embeddings.append(sentence_embeddings)

        region_embeddings = self.get_average_embeddings(train_data)
        region_embeddings = {label: torch.tensor(embedding).to(self.device) for label, embedding in region_embeddings.items()}
        return input_embeddings

    def extract_labels_predictions(self, results_list, key_index=0):
        """Split ranking results into gold labels and the label ranked at ``key_index``.

        Args:
            results_list: output of ``compare_to_regions``.
            key_index: rank to read from each ``'similarity_scores'`` dict
                (0 is the most similar label).

        Returns:
            ``(labels, predictions)`` as two parallel lists.
        """
        labels = []
        predictions = []
        for result in results_list:
            labels.append(result['label'])
            ranked_labels = list(result['similarity_scores'].keys())
            predictions.append(ranked_labels[key_index])

        return labels, predictions

    def compute_top_k(self, top_k):
        """Print and return the fraction of test samples whose label is among the ``top_k`` most similar.

        Labels and ranked keys are compared as strings. The result is rounded
        to three decimals.
        """
        # Use the loaders stored by __init__; there is no self.train_data and
        # compare_to_regions takes three arguments.
        result_list = self.compare_to_regions(self.train_loader, self.test_data, self.test_loader)

        labels = [str(result['label']) for result in result_list]
        predictions = [
            [str(key) for key in list(result['similarity_scores'].keys())[:top_k]]
            for result in result_list
        ]

        total_examples = len(labels)
        correct_predictions = sum(
            1 for label, top_labels in zip(labels, predictions) if label in top_labels
        )

        top_k_accuracy = round((correct_predictions / total_examples),3)
        print(f"{top_k} accuracy: {top_k_accuracy}")
        return top_k_accuracy


    def compute_metrics(self):
        """Print classifier accuracy, then compute accuracy/precision/recall/F1 of the top-1 ranking.

        Returns:
            Dict mapping metric name to whatever ``metric.compute`` returned.
        """
        self.test_AA_acc()
        result_list = self.compare_to_regions(self.train_loader, self.test_data, self.test_loader)
        references, predictions = self.extract_labels_predictions(result_list)
        predictions = [str(i) for i in predictions]
        references = [str(i) for i in references]

        metric_names = ["accuracy", "precision", "recall", "f1"]
        metrics = {metric_name: load_metric(metric_name, trust_remote_code=True) for metric_name in metric_names}
        results = {}
        for metric_name, metric in metrics.items():
            if metric_name == "accuracy":
                score = metric.compute(predictions=predictions, references=references)
            else:
                # Prefer the weighted average; fall back to per-class scores
                # if the metric rejects that setting.
                try:
                    score = metric.compute(predictions=predictions, references=references, average="weighted")
                except ValueError:
                    score = metric.compute(predictions=predictions, references=references, average=None)

            print(f"{metric_name} score: {score}")
            results[metric_name] = score

        return results

# main

In [None]:
def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    # save_dir: str,
    ):
    """Train, save and evaluate the dual encoder on the paraphrased DiffusionDB CSVs.

    Loads the train/test CSVs through ``get_csv_dataset``, renames the
    ``user_name``/``prompt`` columns to ``label``/``sentence1``, builds a
    ``StyleEncoder`` (from a saved checkpoint file), a ``ContentEncoder``
    ("bert-base-cased") and the ``DualEncoder`` that wraps both. It then calls
    ``train_batch``, writes three ``state_dict`` files and finally calls
    ``test_step`` with the freshly saved style-encoder weights.

    Args:
        batch_size: Forwarded to ``DualEncoder``.
        num_epochs: Forwarded to ``DualEncoder``.
        log_step: Forwarded to ``DualEncoder``.
        num_iters: Forwarded to ``DualEncoder``.
    """
    torch.backends.cudnn.benchmark = True

    data_dir = "/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased"
    output_dir = "/content/drive/MyDrive/msc_project/model/contrastive/club"

    # Prepare data
    train_data = get_csv_dataset(f"{data_dir}/train_random100_label_1.csv")
    test_data = get_csv_dataset(f"{data_dir}/test_random100_label_1.csv")

    column_map = (("user_name", "label"), ("prompt", "sentence1"))
    for old_name, new_name in column_map:
        train_data = train_data.rename_column(old_name, new_name)
    for old_name, new_name in column_map:
        test_data = test_data.rename_column(old_name, new_name)

    # Both CSVs are read from the 'train' key of the loaded object.
    # NOTE: shuffle() is called without a seed, so row order differs between runs.
    train_data = train_data['train'].shuffle()
    test_data = test_data['train'].shuffle()

    # Load pretrained model
    style_checkpoint = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/diffusiondb100_lcl_coe2_para_bert-base-cased_coe2.0_temp0.1_unit2_epoch30/diffusiondb100_lcl_coe2_para_val0.73512_e24.pt'
    content_checkpoint = "bert-base-cased"

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=train_data,
                                 test_data=test_data)

    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=train_data,
                                     test_data=test_data)

    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    # Make sure the output folder exists *before* training: otherwise a
    # missing directory only surfaces at torch.save, after the whole run.
    os.makedirs(output_dir, exist_ok=True)

    # Train Step
    print("-------------------- Training --------------------")
    dual_encoder.train_batch(train_data, test_data)

    # Save model
    print("-------------------- Saving Model --------------------")
    style_weights_path = f"{output_dir}/style_encoder.pt"
    torch.save(style_encoder.model.state_dict(), style_weights_path)
    torch.save(content_encoder.encoder.state_dict(), f"{output_dir}/content_encoder.pt")
    torch.save(dual_encoder.state_dict(), f"{output_dir}/dual_encoder.pt")

    # Test Step: evaluate using the style-encoder weights written just above.
    print("-------------------- Evaluation --------------------")
    result_path = f"{output_dir}/result.json"
    dual_encoder.test_step(style_weights_path, train_data, test_data, result_path)

if __name__ == "__main__":
    # Command-line overrides for the run settings. The defaults are the
    # values this cell has always run with (note num_iters=1).
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_iters", type=int, default=1)
    parser.add_argument("--log_step", type=int, default=10)

    # parse_known_args rather than parse_args: inside a notebook kernel
    # sys.argv typically carries launcher flags (e.g. "-f <connection file>")
    # that this parser does not declare and must ignore.
    training_args = vars(parser.parse_known_args()[0])

    main(**training_args)


In [None]:
def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    # save_dir: str,
    ):
    """Train, save and evaluate the dual encoder on the paraphrased DiffusionDB CSVs.

    Loads the train/test CSVs through ``get_csv_dataset``, renames the
    ``user_name``/``prompt`` columns to ``label``/``sentence1``, builds a
    ``StyleEncoder`` (from a saved checkpoint file), a ``ContentEncoder``
    ("bert-base-cased") and the ``DualEncoder`` that wraps both. It then calls
    ``train_batch``, writes three ``state_dict`` files and finally calls
    ``test_step`` with the freshly saved style-encoder weights.

    Args:
        batch_size: Forwarded to ``DualEncoder``.
        num_epochs: Forwarded to ``DualEncoder``.
        log_step: Forwarded to ``DualEncoder``.
        num_iters: Forwarded to ``DualEncoder``.
    """
    torch.backends.cudnn.benchmark = True

    data_dir = "/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased"
    output_dir = "/content/drive/MyDrive/msc_project/model/contrastive/club"

    # Prepare data
    train_data = get_csv_dataset(f"{data_dir}/train_random100_label_1.csv")
    test_data = get_csv_dataset(f"{data_dir}/test_random100_label_1.csv")

    column_map = (("user_name", "label"), ("prompt", "sentence1"))
    for old_name, new_name in column_map:
        train_data = train_data.rename_column(old_name, new_name)
    for old_name, new_name in column_map:
        test_data = test_data.rename_column(old_name, new_name)

    # Both CSVs are read from the 'train' key of the loaded object.
    # NOTE: shuffle() is called without a seed, so row order differs between runs.
    train_data = train_data['train'].shuffle()
    test_data = test_data['train'].shuffle()

    # Load pretrained model
    style_checkpoint = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/diffusiondb100_lcl_coe2_para_bert-base-cased_coe2.0_temp0.1_unit2_epoch30/diffusiondb100_lcl_coe2_para_val0.73512_e24.pt'
    content_checkpoint = "bert-base-cased"

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=train_data,
                                 test_data=test_data)

    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=train_data,
                                     test_data=test_data)

    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    # Make sure the output folder exists *before* training: otherwise a
    # missing directory only surfaces at torch.save, after the whole run.
    os.makedirs(output_dir, exist_ok=True)

    # Train Step
    print("-------------------- Training --------------------")
    dual_encoder.train_batch(train_data, test_data)

    # Save model
    print("-------------------- Saving Model --------------------")
    style_weights_path = f"{output_dir}/style_encoder.pt"
    torch.save(style_encoder.model.state_dict(), style_weights_path)
    torch.save(content_encoder.encoder.state_dict(), f"{output_dir}/content_encoder.pt")
    torch.save(dual_encoder.state_dict(), f"{output_dir}/dual_encoder.pt")

    # Test Step: evaluate using the style-encoder weights written just above.
    print("-------------------- Evaluation --------------------")
    result_path = f"{output_dir}/result.json"
    dual_encoder.test_step(style_weights_path, train_data, test_data, result_path)

if __name__ == "__main__":
    # Command-line overrides for the run settings. The defaults are the
    # values this cell has always run with (note num_iters=1).
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_iters", type=int, default=1)
    parser.add_argument("--log_step", type=int, default=10)

    # parse_known_args rather than parse_args: inside a notebook kernel
    # sys.argv typically carries launcher flags (e.g. "-f <connection file>")
    # that this parser does not declare and must ignore.
    training_args = vars(parser.parse_known_args()[0])

    main(**training_args)


In [None]:
def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    # save_dir: str,
    ):
    """Train, save and evaluate the dual encoder on the processed DiffusionDB CSVs.

    Loads the train/test CSVs through ``get_csv_dataset``, renames the
    ``user_name``/``prompt`` columns to ``label``/``sentence1``, builds a
    ``StyleEncoder`` (from a saved checkpoint file), a ``ContentEncoder``
    ("bert-base-cased") and the ``DualEncoder`` that wraps both. It then calls
    ``train_batch``, writes three ``state_dict`` files and finally calls
    ``test_step`` with the freshly saved style-encoder weights.

    Args:
        batch_size: Forwarded to ``DualEncoder``.
        num_epochs: Forwarded to ``DualEncoder``.
        log_step: Forwarded to ``DualEncoder``.
        num_iters: Forwarded to ``DualEncoder``.
    """
    torch.backends.cudnn.benchmark = True

    data_dir = "/content/drive/MyDrive/msc_project/data/diffusiondb/processed"
    output_dir = "/content/drive/MyDrive/msc_project/model/contrastive/club"

    # Prepare data
    train_data = get_csv_dataset(f"{data_dir}/train_random100_label_1.csv")
    test_data = get_csv_dataset(f"{data_dir}/test_random100_label_1.csv")

    column_map = (("user_name", "label"), ("prompt", "sentence1"))
    for old_name, new_name in column_map:
        train_data = train_data.rename_column(old_name, new_name)
    for old_name, new_name in column_map:
        test_data = test_data.rename_column(old_name, new_name)

    # Both CSVs are read from the 'train' key of the loaded object.
    # NOTE: shuffle() is called without a seed, so row order differs between runs.
    train_data = train_data['train'].shuffle()
    test_data = test_data['train'].shuffle()

    # Load pretrained model
    style_checkpoint = '/content/drive/MyDrive/msc_project/model/contrastive/contrax/exp_data/diffusiondb100_supcon_bert-base-cased_coe1_temp0.1_unit2_epoch30/diffusiondb100_supcon_val0.78125_e26.pt'
    content_checkpoint = "bert-base-cased"

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=train_data,
                                 test_data=test_data)

    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=train_data,
                                     test_data=test_data)

    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    # Make sure the output folder exists *before* training: otherwise a
    # missing directory only surfaces at torch.save, after the whole run.
    os.makedirs(output_dir, exist_ok=True)

    # Train Step
    print("-------------------- Training --------------------")
    dual_encoder.train_batch(train_data, test_data)

    # Save model
    print("-------------------- Saving Model --------------------")
    style_weights_path = f"{output_dir}/style_encoder.pt"
    torch.save(style_encoder.model.state_dict(), style_weights_path)
    torch.save(content_encoder.encoder.state_dict(), f"{output_dir}/content_encoder.pt")
    torch.save(dual_encoder.state_dict(), f"{output_dir}/dual_encoder.pt")

    # Test Step: evaluate using the style-encoder weights written just above.
    print("-------------------- Evaluation --------------------")
    result_path = f"{output_dir}/result.json"
    dual_encoder.test_step(style_weights_path, train_data, test_data, result_path)

if __name__ == "__main__":
    # Command-line overrides for the run settings. The defaults are the
    # values this cell has always run with (note num_iters=1).
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_iters", type=int, default=1)
    parser.add_argument("--log_step", type=int, default=10)

    # parse_known_args rather than parse_args: inside a notebook kernel
    # sys.argv typically carries launcher flags (e.g. "-f <connection file>")
    # that this parser does not declare and must ignore.
    training_args = vars(parser.parse_known_args()[0])

    main(**training_args)


In [None]:
# Display the `train_data` object as this cell's output.
# NOTE(review): the `main` functions in the cells above bind `train_data` only
# as a local variable; confirm a module-level `train_data` exists before
# running this cell on a fresh kernel.
train_data

# current

In [None]:
# Ask PyTorch's CUDA caching allocator to release the cached GPU memory it is
# not currently using.
torch.cuda.empty_cache()

In [None]:
import pandas as pd
import torch


def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    # save_dir: str,
    ):
    """Train the dual encoder on the DiffusionDB "vary" train/test CSVs.

    Reads the train and test CSVs with pandas, keeps the ``prompt`` and
    ``user_name`` columns renamed to ``content`` and ``Target``, builds a
    ``StyleEncoder`` (from a saved checkpoint file), a ``ContentEncoder``
    ("bert-base-cased") and the ``DualEncoder`` wrapping both, then calls
    ``dual_encoder.train``. No evaluation step is run by this function.

    Args:
        batch_size: Forwarded to ``DualEncoder``.
        num_epochs: Forwarded to ``DualEncoder``.
        log_step: Forwarded to ``DualEncoder``.
        num_iters: Forwarded to ``DualEncoder``.
    """
    torch.cuda.empty_cache()

    data_dir = '/content/drive/MyDrive/msc_project/data/diffusiondb/vary'

    def _load_split(csv_path):
        """Read one CSV and return its text/author columns as content/Target."""
        frame = pd.read_csv(csv_path)
        return frame[['prompt', 'user_name']].rename(
            columns={'prompt': 'content', 'user_name': 'Target'})

    # Prepare data
    # NOTE(review): the validation CSV was previously read here but never
    # used; `train` below receives the *test* split as its second argument —
    # confirm that is intended.
    nlp_train = _load_split(f'{data_dir}/train_random100_150_label_1.csv')
    nlp_test = _load_split(f'{data_dir}/test_random100_150_label_1.csv')

    # Load pretrained model
    style_checkpoint = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/diffusiondb100_150_lcl_coe1_bert-base-cased_coe1.0_temp0.1_unit2_epoch30/diffusiondb100_150_lcl_coe1_val0.73867_e29.pt'
    content_checkpoint = "bert-base-cased"

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=nlp_train,
                                 test_data=nlp_test)

    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=nlp_train,
                                     test_data=nlp_test)

    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    print("-------------------- Training --------------------")
    dual_encoder.train(nlp_train, nlp_test)

if __name__ == "__main__":
    # Command-line overrides for the run settings. The defaults are the
    # values this cell has always run with (note num_iters=1).
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_iters", type=int, default=1)
    parser.add_argument("--log_step", type=int, default=10)

    # parse_known_args rather than parse_args: inside a notebook kernel
    # sys.argv typically carries launcher flags (e.g. "-f <connection file>")
    # that this parser does not declare and must ignore.
    training_args = vars(parser.parse_known_args()[0])

    main(**training_args)


In [None]:
import pandas as pd
import torch


def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    # save_dir: str,
    ):
    """Train the dual encoder on the blogs50 train/test CSVs.

    Reads the train and test CSVs with pandas, keeps the ``text`` and
    ``author_id`` columns renamed to ``content`` and ``Target``, builds a
    ``StyleEncoder`` (from a saved checkpoint file), a ``ContentEncoder``
    ("bert-base-cased") and the ``DualEncoder`` wrapping both, then calls
    ``dual_encoder.train``. No evaluation step is run by this function.

    Args:
        batch_size: Forwarded to ``DualEncoder``.
        num_epochs: Forwarded to ``DualEncoder``.
        log_step: Forwarded to ``DualEncoder``.
        num_iters: Forwarded to ``DualEncoder``.
    """
    torch.cuda.empty_cache()

    data_dir = '/content/drive/MyDrive/msc_project/data/blogs/processed'

    def _load_split(csv_path):
        """Read one CSV and return its text/author columns as content/Target."""
        frame = pd.read_csv(csv_path)
        return frame[['text', 'author_id']].rename(
            columns={'text': 'content', 'author_id': 'Target'})

    # Prepare data
    # NOTE(review): the validation CSV was previously read here but never
    # used; `train` below receives the *test* split as its second argument —
    # confirm that is intended.
    nlp_train = _load_split(f'{data_dir}/blogs50_train.csv')
    nlp_test = _load_split(f'{data_dir}/blogs50_AA_test.csv')

    # Load pretrained model
    style_checkpoint = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/blogs50_lcl_coe1_bert-base-cased_coe1.0_temp0.1_unit2_epoch30/blogs50_lcl_coe1_val0.83212_finale29.pt'
    content_checkpoint = "bert-base-cased"

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=nlp_train,
                                 test_data=nlp_test)

    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=nlp_train,
                                     test_data=nlp_test)

    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    print("-------------------- Training --------------------")
    dual_encoder.train(nlp_train, nlp_test)

if __name__ == "__main__":
    # Command-line overrides for the run settings. The defaults are the
    # values this cell has always run with (note num_iters=1).
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_iters", type=int, default=1)
    parser.add_argument("--log_step", type=int, default=10)

    # parse_known_args rather than parse_args: inside a notebook kernel
    # sys.argv typically carries launcher flags (e.g. "-f <connection file>")
    # that this parser does not declare and must ignore.
    training_args = vars(parser.parse_known_args()[0])

    main(**training_args)


In [None]:
import pandas as pd
import torch


def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    # save_dir: str,
    ):
    """Train the dual encoder on the blogs50 train/test CSVs.

    Reads the train and test CSVs with pandas, keeps the ``text`` and
    ``author_id`` columns renamed to ``content`` and ``Target``, builds a
    ``StyleEncoder`` (from a saved checkpoint file), a ``ContentEncoder``
    ("bert-base-cased") and the ``DualEncoder`` wrapping both, then calls
    ``dual_encoder.train``. No evaluation step is run by this function.

    Args:
        batch_size: Forwarded to ``DualEncoder``.
        num_epochs: Forwarded to ``DualEncoder``.
        log_step: Forwarded to ``DualEncoder``.
        num_iters: Forwarded to ``DualEncoder``.
    """
    torch.cuda.empty_cache()

    data_dir = '/content/drive/MyDrive/msc_project/data/blogs/processed'

    def _load_split(csv_path):
        """Read one CSV and return its text/author columns as content/Target."""
        frame = pd.read_csv(csv_path)
        return frame[['text', 'author_id']].rename(
            columns={'text': 'content', 'author_id': 'Target'})

    # Prepare data
    # NOTE(review): the validation CSV was previously read here but never
    # used; `train` below receives the *test* split as its second argument —
    # confirm that is intended.
    nlp_train = _load_split(f'{data_dir}/blogs50_train.csv')
    nlp_test = _load_split(f'{data_dir}/blogs50_AA_test.csv')

    # Load pretrained model
    style_checkpoint = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/blogs50_lcl_coe1_bert-base-cased_coe1.0_temp0.1_unit2_epoch30/blogs50_lcl_coe1_val0.83212_finale29.pt'
    content_checkpoint = "bert-base-cased"

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=nlp_train,
                                 test_data=nlp_test)

    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=nlp_train,
                                     test_data=nlp_test)

    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    print("-------------------- Training --------------------")
    dual_encoder.train(nlp_train, nlp_test)

if __name__ == "__main__":
    # Command-line overrides for the run settings. The defaults are the
    # values this cell has always run with (note num_iters=1).
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_iters", type=int, default=1)
    parser.add_argument("--log_step", type=int, default=10)

    # parse_known_args rather than parse_args: inside a notebook kernel
    # sys.argv typically carries launcher flags (e.g. "-f <connection file>")
    # that this parser does not declare and must ignore.
    training_args = vars(parser.parse_known_args()[0])

    main(**training_args)


In [None]:
import pandas as pd
import torch


def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    # save_dir: str,
    ):
    """Train the dual encoder on the blogs50 train/test CSVs.

    Reads the train and test CSVs with pandas, keeps the ``text`` and
    ``author_id`` columns renamed to ``content`` and ``Target``, builds a
    ``StyleEncoder`` (from a saved checkpoint file), a ``ContentEncoder``
    ("bert-base-cased") and the ``DualEncoder`` wrapping both, then calls
    ``dual_encoder.train``. No evaluation step is run by this function.

    Args:
        batch_size: Forwarded to ``DualEncoder``.
        num_epochs: Forwarded to ``DualEncoder``.
        log_step: Forwarded to ``DualEncoder``.
        num_iters: Forwarded to ``DualEncoder``.
    """
    torch.cuda.empty_cache()

    data_dir = '/content/drive/MyDrive/msc_project/data/blogs/processed'

    def _load_split(csv_path):
        """Read one CSV and return its text/author columns as content/Target."""
        frame = pd.read_csv(csv_path)
        return frame[['text', 'author_id']].rename(
            columns={'text': 'content', 'author_id': 'Target'})

    # Prepare data
    # NOTE(review): the validation CSV was previously read here but never
    # used; `train` below receives the *test* split as its second argument —
    # confirm that is intended.
    nlp_train = _load_split(f'{data_dir}/blogs50_train.csv')
    nlp_test = _load_split(f'{data_dir}/blogs50_AA_test.csv')

    # Load pretrained model
    style_checkpoint = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/blogs50_lcl_coe1_bert-base-cased_coe1.0_temp0.1_unit2_epoch30/blogs50_lcl_coe1_val0.83212_finale29.pt'
    content_checkpoint = "bert-base-cased"

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=nlp_train,
                                 test_data=nlp_test)

    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=nlp_train,
                                     test_data=nlp_test)

    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    print("-------------------- Training --------------------")
    dual_encoder.train(nlp_train, nlp_test)

if __name__ == "__main__":
    # Command-line overrides for the run settings. The defaults are the
    # values this cell has always run with (note num_iters=1).
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_iters", type=int, default=1)
    parser.add_argument("--log_step", type=int, default=10)

    # parse_known_args rather than parse_args: inside a notebook kernel
    # sys.argv typically carries launcher flags (e.g. "-f <connection file>")
    # that this parser does not declare and must ignore.
    training_args = vars(parser.parse_known_args()[0])

    main(**training_args)


In [None]:
import pandas as pd
import torch


def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    train_csv: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/train_random100_label_1.csv',
    test_csv: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/test_random100_label_1.csv',
    label_column: str = 'user_name',
    style_checkpoint: str = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/diffusiondb100_lcl_coe1_para_bert-base-cased_coe1.0_temp0.1_unit2_epoch30/diffusiondb100_lcl_coe1_para_val0.73264_e16.pt',
    content_checkpoint: str = "bert-base-cased",
    save_dir: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/result.json",
    eval_checkpoint: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/style_encoder_1.pt",
    ):
    """Train the dual encoder on two CSV splits, then run its test step.

    Args:
        batch_size, num_epochs, log_step, num_iters: forwarded unchanged to
            ``DualEncoder``.
        train_csv, test_csv: CSV files that must contain a ``prompt`` column
            and ``label_column``.
        label_column: column used as the label; it is renamed to ``Target``.
        style_checkpoint: checkpoint handed to ``StyleEncoder``.
        content_checkpoint: checkpoint handed to ``ContentEncoder``.
        save_dir: last argument of ``DualEncoder.test_step``. The default is a
            JSON file path, presumably where results are written -- TODO
            confirm against ``DualEncoder``.
        eval_checkpoint: first argument of ``DualEncoder.test_step``,
            presumably the style-encoder weights saved by training -- TODO
            confirm against ``DualEncoder``.

    The new keyword arguments default to the values that used to be
    hard-coded, so ``main(**training_args)`` behaves as before.
    """
    torch.cuda.empty_cache()

    # Prepare data: keep only the prompt text and the label, renamed to
    # 'content' / 'Target' (presumably the names the encoders expect --
    # confirm against StyleEncoder / ContentEncoder).
    nlp_train, nlp_test = (
        pd.read_csv(csv_path)[['prompt', label_column]]
        .rename(columns={'prompt': 'content', label_column: 'Target'})
        for csv_path in (train_csv, test_csv)
    )
    # NOTE(review): the matching val_*.csv split used to be read here but was
    # never used, so the read was dropped. train() below receives the test
    # split as its second argument; confirm whether the validation split was
    # intended instead.

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=nlp_train,
                                 test_data=nlp_test)
    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=nlp_train,
                                     test_data=nlp_test)
    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    print("-------------------- Training --------------------")
    dual_encoder.train(nlp_train, nlp_test)

    # Test step
    print("-------------------- Evaluation --------------------")
    dual_encoder.test_step(eval_checkpoint, nlp_train, nlp_test, save_dir)

if __name__ == "__main__":
    # Integer command-line options matching the keyword arguments passed to
    # main() below. The parser is only declared: nothing in this cell calls
    # parse_args(), so these defaults are not what the run uses.
    # (--save_dir, --do_train and --do_evaluation are disabled.)
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--num_iters", default=100, type=int)
    parser.add_argument("--log_step", default=10, type=int)

    # Hyper-parameters actually handed to main(); note that num_iters is 1
    # here, not the parser default of 100.
    training_args = dict(batch_size=32, num_epochs=20, num_iters=1, log_step=10)

    main(**training_args)


In [None]:
import pandas as pd
import torch


def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    train_csv: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/train_topicseparate100_label_1.csv',
    test_csv: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/test_topicseparate100_label_1.csv',
    label_column: str = 'user_label',
    style_checkpoint: str = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/diffusiondb100_lcl_coe1_para_topic_bert-base-cased_coe1.0_temp0.1_unit2_epoch30/diffusiondb100_lcl_coe1_para_topic_val0.48859_e29.pt',
    content_checkpoint: str = "bert-base-cased",
    ):
    """Train the dual encoder on two CSV splits (topic-separated data by default).

    Args:
        batch_size, num_epochs, log_step, num_iters: forwarded unchanged to
            ``DualEncoder``.
        train_csv, test_csv: CSV files that must contain a ``prompt`` column
            and ``label_column``.
        label_column: column used as the label; it is renamed to ``Target``.
            Defaults to ``user_label`` (other cells of this notebook use
            ``user_name``).
        style_checkpoint: checkpoint handed to ``StyleEncoder``.
        content_checkpoint: checkpoint handed to ``ContentEncoder``.

    The new keyword arguments default to the values that used to be
    hard-coded, so ``main(**training_args)`` behaves as before. The
    evaluation step is not run by this variant.
    """
    torch.cuda.empty_cache()

    # Prepare data: keep only the prompt text and the label, renamed to
    # 'content' / 'Target' (presumably the names the encoders expect --
    # confirm against StyleEncoder / ContentEncoder).
    nlp_train, nlp_test = (
        pd.read_csv(csv_path)[['prompt', label_column]]
        .rename(columns={'prompt': 'content', label_column: 'Target'})
        for csv_path in (train_csv, test_csv)
    )
    # NOTE(review): the matching val_*.csv split used to be read here but was
    # never used, so the read was dropped. train() below receives the test
    # split as its second argument; confirm whether the validation split was
    # intended instead.

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=nlp_train,
                                 test_data=nlp_test)
    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=nlp_train,
                                     test_data=nlp_test)
    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    print("-------------------- Training --------------------")
    dual_encoder.train(nlp_train, nlp_test)

if __name__ == "__main__":
    # Integer command-line options matching the keyword arguments passed to
    # main() below. The parser is only declared: nothing in this cell calls
    # parse_args(), so these defaults are not what the run uses.
    # (--save_dir, --do_train and --do_evaluation are disabled.)
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--num_iters", default=100, type=int)
    parser.add_argument("--log_step", default=10, type=int)

    # Hyper-parameters actually handed to main(); note that num_iters is 1
    # here, not the parser default of 100.
    training_args = dict(batch_size=32, num_epochs=20, num_iters=1, log_step=10)

    main(**training_args)


In [None]:
# Ask PyTorch to release cached GPU memory it is no longer using before the next cell runs.
torch.cuda.empty_cache()

In [None]:
import pandas as pd
import torch


def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    train_csv: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/train_random100_label_1.csv',
    test_csv: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/test_random100_label_1.csv',
    label_column: str = 'user_name',
    style_checkpoint: str = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/diffusiondb100_lcl_coe1_para_bert-base-cased_coe1.0_temp0.1_unit2_epoch30/diffusiondb100_lcl_coe1_para_val0.73264_e16.pt',
    content_checkpoint: str = "bert-base-cased",
    ):
    """Train the dual encoder on two CSV splits.

    Args:
        batch_size, num_epochs, log_step, num_iters: forwarded unchanged to
            ``DualEncoder``.
        train_csv, test_csv: CSV files that must contain a ``prompt`` column
            and ``label_column``.
        label_column: column used as the label; it is renamed to ``Target``.
        style_checkpoint: checkpoint handed to ``StyleEncoder``.
        content_checkpoint: checkpoint handed to ``ContentEncoder``.

    The new keyword arguments default to the values that used to be
    hard-coded, so ``main(**training_args)`` behaves as before. The
    evaluation step is not run by this variant.
    """
    torch.cuda.empty_cache()

    # Prepare data: keep only the prompt text and the label, renamed to
    # 'content' / 'Target' (presumably the names the encoders expect --
    # confirm against StyleEncoder / ContentEncoder).
    nlp_train, nlp_test = (
        pd.read_csv(csv_path)[['prompt', label_column]]
        .rename(columns={'prompt': 'content', label_column: 'Target'})
        for csv_path in (train_csv, test_csv)
    )
    # NOTE(review): the matching val_*.csv split used to be read here but was
    # never used, so the read was dropped. train() below receives the test
    # split as its second argument; confirm whether the validation split was
    # intended instead.

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=nlp_train,
                                 test_data=nlp_test)
    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=nlp_train,
                                     test_data=nlp_test)
    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    print("-------------------- Training --------------------")
    dual_encoder.train(nlp_train, nlp_test)

if __name__ == "__main__":
    # Integer command-line options matching the keyword arguments passed to
    # main() below. The parser is only declared: nothing in this cell calls
    # parse_args(), so these defaults are not what the run uses.
    # (--save_dir, --do_train and --do_evaluation are disabled.)
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--num_iters", default=100, type=int)
    parser.add_argument("--log_step", default=10, type=int)

    # Hyper-parameters actually handed to main(); note that num_iters is 1
    # here, not the parser default of 100.
    training_args = dict(batch_size=32, num_epochs=20, num_iters=1, log_step=10)

    main(**training_args)


In [None]:
import pandas as pd
import torch


def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    train_csv: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/train_random100_label_1.csv',
    test_csv: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/test_random100_label_1.csv',
    label_column: str = 'user_name',
    style_checkpoint: str = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/diffusiondb100_lcl_coe1_para_bert-base-cased_coe1.0_temp0.1_unit2_epoch30/diffusiondb100_lcl_coe1_para_val0.73264_e16.pt',
    content_checkpoint: str = "bert-base-cased",
    save_dir: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/result.json",
    eval_checkpoint: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/style_encoder_1.pt",
    ):
    """Train the dual encoder on two CSV splits, then run its test step.

    Args:
        batch_size, num_epochs, log_step, num_iters: forwarded unchanged to
            ``DualEncoder``.
        train_csv, test_csv: CSV files that must contain a ``prompt`` column
            and ``label_column``.
        label_column: column used as the label; it is renamed to ``Target``.
        style_checkpoint: checkpoint handed to ``StyleEncoder``.
        content_checkpoint: checkpoint handed to ``ContentEncoder``.
        save_dir: last argument of ``DualEncoder.test_step``. The default is a
            JSON file path, presumably where results are written -- TODO
            confirm against ``DualEncoder``.
        eval_checkpoint: first argument of ``DualEncoder.test_step``,
            presumably the style-encoder weights saved by training -- TODO
            confirm against ``DualEncoder``.

    The new keyword arguments default to the values that used to be
    hard-coded, so ``main(**training_args)`` behaves as before.
    """
    torch.cuda.empty_cache()

    # Prepare data: keep only the prompt text and the label, renamed to
    # 'content' / 'Target' (presumably the names the encoders expect --
    # confirm against StyleEncoder / ContentEncoder).
    nlp_train, nlp_test = (
        pd.read_csv(csv_path)[['prompt', label_column]]
        .rename(columns={'prompt': 'content', label_column: 'Target'})
        for csv_path in (train_csv, test_csv)
    )
    # NOTE(review): the matching val_*.csv split used to be read here but was
    # never used, so the read was dropped. train() below receives the test
    # split as its second argument; confirm whether the validation split was
    # intended instead.

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=nlp_train,
                                 test_data=nlp_test)
    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=nlp_train,
                                     test_data=nlp_test)
    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    print("-------------------- Training --------------------")
    dual_encoder.train(nlp_train, nlp_test)

    # Test step
    print("-------------------- Evaluation --------------------")
    dual_encoder.test_step(eval_checkpoint, nlp_train, nlp_test, save_dir)

if __name__ == "__main__":
    # Integer command-line options matching the keyword arguments passed to
    # main() below. The parser is only declared: nothing in this cell calls
    # parse_args(), so these defaults are not what the run uses.
    # (--save_dir, --do_train and --do_evaluation are disabled.)
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--num_iters", default=100, type=int)
    parser.add_argument("--log_step", default=10, type=int)

    # Hyper-parameters actually handed to main(); note that num_iters is 1
    # here, not the parser default of 100.
    training_args = dict(batch_size=32, num_epochs=20, num_iters=1, log_step=10)

    main(**training_args)


# others

In [None]:
import pandas as pd
import torch


def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    train_csv: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/train_random100_label_1.csv',
    test_csv: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/test_random100_label_1.csv',
    label_column: str = 'user_name',
    style_checkpoint: str = '/content/drive/MyDrive/msc_project/model/contrastive/contrax/exp_data/diffusiondb100_supcon_para_bert-base-cased_coe1_temp0.1_unit2_epoch30/diffusiondb100_supcon_para_val0.72321_e29.pt',
    content_checkpoint: str = "bert-base-cased",
    save_dir: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/result.json",
    eval_checkpoint: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/style_encoder_1.pt",
    ):
    """Train the dual encoder on two CSV splits, then run its test step.

    Args:
        batch_size, num_epochs, log_step, num_iters: forwarded unchanged to
            ``DualEncoder``.
        train_csv, test_csv: CSV files that must contain a ``prompt`` column
            and ``label_column``.
        label_column: column used as the label; it is renamed to ``Target``.
        style_checkpoint: checkpoint handed to ``StyleEncoder``.
        content_checkpoint: checkpoint handed to ``ContentEncoder``.
        save_dir: last argument of ``DualEncoder.test_step``. The default is a
            JSON file path, presumably where results are written -- TODO
            confirm against ``DualEncoder``.
        eval_checkpoint: first argument of ``DualEncoder.test_step``,
            presumably the style-encoder weights saved by training -- TODO
            confirm against ``DualEncoder``.

    The new keyword arguments default to the values that used to be
    hard-coded, so ``main(**training_args)`` behaves as before.
    """
    torch.cuda.empty_cache()

    # Prepare data: keep only the prompt text and the label, renamed to
    # 'content' / 'Target' (presumably the names the encoders expect --
    # confirm against StyleEncoder / ContentEncoder).
    nlp_train, nlp_test = (
        pd.read_csv(csv_path)[['prompt', label_column]]
        .rename(columns={'prompt': 'content', label_column: 'Target'})
        for csv_path in (train_csv, test_csv)
    )
    # NOTE(review): the matching val_*.csv split used to be read here but was
    # never used, so the read was dropped. train() below receives the test
    # split as its second argument; confirm whether the validation split was
    # intended instead.

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=nlp_train,
                                 test_data=nlp_test)
    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=nlp_train,
                                     test_data=nlp_test)
    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    print("-------------------- Training --------------------")
    dual_encoder.train(nlp_train, nlp_test)

    # Test step
    print("-------------------- Evaluation --------------------")
    dual_encoder.test_step(eval_checkpoint, nlp_train, nlp_test, save_dir)

if __name__ == "__main__":
    # Integer command-line options matching the keyword arguments passed to
    # main() below. The parser is only declared: nothing in this cell calls
    # parse_args(), so these defaults are not what the run uses.
    # (--save_dir, --do_train and --do_evaluation are disabled.)
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--num_iters", default=100, type=int)
    parser.add_argument("--log_step", default=10, type=int)

    # Hyper-parameters actually handed to main(); note that num_iters is 1
    # here, not the parser default of 100.
    training_args = dict(batch_size=32, num_epochs=20, num_iters=1, log_step=10)

    main(**training_args)


In [None]:
import pandas as pd
import torch


def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    train_csv: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/train_random100_label_1.csv',
    test_csv: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/test_random100_label_1.csv',
    label_column: str = 'user_name',
    style_checkpoint: str = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/diffusiondb100_lcl_coe2_para_bert-base-cased_coe2.0_temp0.1_unit2_epoch30/diffusiondb100_lcl_coe2_para_val0.73512_e24.pt',
    content_checkpoint: str = "bert-base-cased",
    save_dir: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/result.json",
    eval_checkpoint: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/style_encoder_1.pt",
    ):
    """Train the dual encoder on two CSV splits, then run its test step.

    Args:
        batch_size, num_epochs, log_step, num_iters: forwarded unchanged to
            ``DualEncoder``.
        train_csv, test_csv: CSV files that must contain a ``prompt`` column
            and ``label_column``.
        label_column: column used as the label; it is renamed to ``Target``.
        style_checkpoint: checkpoint handed to ``StyleEncoder``.
        content_checkpoint: checkpoint handed to ``ContentEncoder``.
        save_dir: last argument of ``DualEncoder.test_step``. The default is a
            JSON file path, presumably where results are written -- TODO
            confirm against ``DualEncoder``.
        eval_checkpoint: first argument of ``DualEncoder.test_step``,
            presumably the style-encoder weights saved by training -- TODO
            confirm against ``DualEncoder``.

    The new keyword arguments default to the values that used to be
    hard-coded, so ``main(**training_args)`` behaves as before.
    """
    torch.cuda.empty_cache()

    # Prepare data: keep only the prompt text and the label, renamed to
    # 'content' / 'Target' (presumably the names the encoders expect --
    # confirm against StyleEncoder / ContentEncoder).
    nlp_train, nlp_test = (
        pd.read_csv(csv_path)[['prompt', label_column]]
        .rename(columns={'prompt': 'content', label_column: 'Target'})
        for csv_path in (train_csv, test_csv)
    )
    # NOTE(review): the matching val_*.csv split used to be read here but was
    # never used, so the read was dropped. train() below receives the test
    # split as its second argument; confirm whether the validation split was
    # intended instead.

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=nlp_train,
                                 test_data=nlp_test)
    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=nlp_train,
                                     test_data=nlp_test)
    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    print("-------------------- Training --------------------")
    dual_encoder.train(nlp_train, nlp_test)

    # Test step
    print("-------------------- Evaluation --------------------")
    dual_encoder.test_step(eval_checkpoint, nlp_train, nlp_test, save_dir)

if __name__ == "__main__":
    # Integer command-line options matching the keyword arguments passed to
    # main() below. The parser is only declared: nothing in this cell calls
    # parse_args(), so these defaults are not what the run uses.
    # (--save_dir, --do_train and --do_evaluation are disabled.)
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--num_iters", default=100, type=int)
    parser.add_argument("--log_step", default=10, type=int)

    # Hyper-parameters actually handed to main(); note that num_iters is 1
    # here, not the parser default of 100.
    training_args = dict(batch_size=32, num_epochs=20, num_iters=1, log_step=10)

    main(**training_args)


In [None]:
def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    train_file: str = "/content/drive/MyDrive/msc_project/data/diffusiondb/processed/train_random100_label_1.csv",
    test_file: str = "/content/drive/MyDrive/msc_project/data/diffusiondb/processed/test_random100_label_1.csv",
    label_column: str = "user_name",
    style_checkpoint: str = '/content/drive/MyDrive/msc_project/model/contrastive/contrax/exp_data/diffusiondb100_supcon_bert-base-cased_coe1_temp0.1_unit2_epoch30/diffusiondb100_supcon_val0.78125_e26.pt',
    content_checkpoint: str = "bert-base-cased",
    save_dir: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/result.json",
    eval_checkpoint: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/style_encoder_1.pt",
    shuffle_seed=None,
    ):
    """Build the dual encoder and run its test step (no training).

    Args:
        batch_size, num_epochs, log_step, num_iters: forwarded unchanged to
            ``DualEncoder``.
        train_file, test_file: CSV files that must contain a ``prompt`` column
            and ``label_column``.
        label_column: column renamed to ``label``.
        style_checkpoint: checkpoint handed to ``StyleEncoder``.
        content_checkpoint: checkpoint handed to ``ContentEncoder``.
        save_dir: last argument of ``DualEncoder.test_step``. The default is a
            JSON file path, presumably where results are written -- TODO
            confirm against ``DualEncoder``.
        eval_checkpoint: first argument of ``DualEncoder.test_step``,
            presumably the style-encoder weights to evaluate -- TODO confirm
            against ``DualEncoder``.
        shuffle_seed: forwarded to ``Dataset.shuffle(seed=...)``. ``None``
            (the default) keeps the previous unseeded shuffle; pass an int
            for a reproducible row order.

    The new keyword arguments default to the values that used to be
    hard-coded, so ``main(**training_args)`` behaves as before.
    """
    torch.backends.cudnn.benchmark = True

    # Prepare data. get_csv_dataset() is given a single file, and the loaded
    # rows are read back from the 'train' key for both files. The columns are
    # renamed to 'sentence1' / 'label' before shuffling.
    train_data, test_data = (
        get_csv_dataset(csv_file)
        .rename_column(label_column, "label")
        .rename_column("prompt", "sentence1")['train']
        .shuffle(seed=shuffle_seed)
        for csv_file in (train_file, test_file)
    )

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=train_data,
                                 test_data=test_data)
    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=train_data,
                                     test_data=test_data)
    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    # Test step
    print("-------------------- Evaluation --------------------")
    dual_encoder.test_step(eval_checkpoint, train_data, test_data, save_dir)

if __name__ == "__main__":
    # Integer command-line options matching the keyword arguments passed to
    # main() below. The parser is only declared: nothing in this cell calls
    # parse_args(), so these defaults are not what the run uses.
    # (--save_dir, --do_train and --do_evaluation are disabled.)
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--num_iters", default=100, type=int)
    parser.add_argument("--log_step", default=10, type=int)

    # Hyper-parameters actually handed to main(); note that num_iters is 1
    # here, not the parser default of 100.
    training_args = dict(batch_size=32, num_epochs=20, num_iters=1, log_step=10)

    main(**training_args)


In [None]:
def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    train_file: str = "/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/train_random100_label_1.csv",
    test_file: str = "/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/test_random100_label_1.csv",
    label_column: str = "user_name",
    style_checkpoint: str = '/content/drive/MyDrive/msc_project/model/contrastive/contrax/exp_data/diffusiondb100_supcon_para_bert-base-cased_coe1_temp0.1_unit2_epoch30/diffusiondb100_supcon_para_val0.72321_e29.pt',
    content_checkpoint: str = "bert-base-cased",
    save_dir: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/result.json",
    shuffle_seed=None,
    ):
    """Build the dual encoder and run its test step (no training).

    Args:
        batch_size, num_epochs, log_step, num_iters: forwarded unchanged to
            ``DualEncoder``.
        train_file, test_file: CSV files that must contain a ``prompt`` column
            and ``label_column``.
        label_column: column renamed to ``label``.
        style_checkpoint: checkpoint handed to ``StyleEncoder`` and also passed
            as the first argument of ``DualEncoder.test_step``.
        content_checkpoint: checkpoint handed to ``ContentEncoder``.
        save_dir: last argument of ``DualEncoder.test_step``. The default is a
            JSON file path, presumably where results are written -- TODO
            confirm against ``DualEncoder``.
        shuffle_seed: forwarded to ``Dataset.shuffle(seed=...)``. ``None``
            (the default) keeps the previous unseeded shuffle; pass an int
            for a reproducible row order.

    The new keyword arguments default to the values that used to be
    hard-coded, so ``main(**training_args)`` behaves as before.
    """
    torch.backends.cudnn.benchmark = True

    # Prepare data. get_csv_dataset() is given a single file, and the loaded
    # rows are read back from the 'train' key for both files. The columns are
    # renamed to 'sentence1' / 'label' before shuffling.
    train_data, test_data = (
        get_csv_dataset(csv_file)
        .rename_column(label_column, "label")
        .rename_column("prompt", "sentence1")['train']
        .shuffle(seed=shuffle_seed)
        for csv_file in (train_file, test_file)
    )

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=train_data,
                                 test_data=test_data)
    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=train_data,
                                     test_data=test_data)
    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    # Test step: evaluated with the same checkpoint the style encoder was
    # built from.
    print("-------------------- Evaluation --------------------")
    dual_encoder.test_step(style_checkpoint, train_data, test_data, save_dir)

if __name__ == "__main__":
    # Integer command-line options matching the keyword arguments passed to
    # main() below. The parser is only declared: nothing in this cell calls
    # parse_args(), so these defaults are not what the run uses.
    # (--save_dir, --do_train and --do_evaluation are disabled.)
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--num_iters", default=100, type=int)
    parser.add_argument("--log_step", default=10, type=int)

    # Hyper-parameters actually handed to main(); note that num_iters is 1
    # here, not the parser default of 100.
    training_args = dict(batch_size=32, num_epochs=20, num_iters=1, log_step=10)

    main(**training_args)

# average
# accuracy score: {'accuracy': 0.7085}
# precision score: {'precision': 0.7262531009903335}
# recall score: {'recall': 0.7085}
# f1 score: {'f1': 0.711030859836859}

# accuracy score: {'accuracy': 0.7155}
# precision score: {'precision': 0.7354965007653969}
# recall score: {'recall': 0.7155}
# f1 score: {'f1': 0.719259827630641}

# kmeans
# accuracy score: {'accuracy': 0.7085}
# precision score: {'precision': 0.7262531009903335}
# recall score: {'recall': 0.7085}
# f1 score: {'f1': 0.711030859836859}

In [None]:
def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    train_file: str = "/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/train_random100_label_1.csv",
    test_file: str = "/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/test_random100_label_1.csv",
    label_column: str = "user_name",
    style_checkpoint: str = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/diffusiondb100_lcl_para_bert-base-cased_coe1.0_temp0.1_unit2_epoch30/diffusiondb100_lcl_para_val0.72073_e29.pt',
    content_checkpoint: str = "bert-base-cased",
    save_dir: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/result.json",
    shuffle_seed=None,
    ):
    """Build the dual encoder and run its test step (no training).

    Args:
        batch_size, num_epochs, log_step, num_iters: forwarded unchanged to
            ``DualEncoder``.
        train_file, test_file: CSV files that must contain a ``prompt`` column
            and ``label_column``.
        label_column: column renamed to ``label``.
        style_checkpoint: checkpoint handed to ``StyleEncoder`` and also passed
            as the first argument of ``DualEncoder.test_step``.
        content_checkpoint: checkpoint handed to ``ContentEncoder``.
        save_dir: last argument of ``DualEncoder.test_step``. The default is a
            JSON file path, presumably where results are written -- TODO
            confirm against ``DualEncoder``.
        shuffle_seed: forwarded to ``Dataset.shuffle(seed=...)``. ``None``
            (the default) keeps the previous unseeded shuffle; pass an int
            for a reproducible row order.

    The new keyword arguments default to the values that used to be
    hard-coded, so ``main(**training_args)`` behaves as before.
    """
    torch.backends.cudnn.benchmark = True

    # Prepare data. get_csv_dataset() is given a single file, and the loaded
    # rows are read back from the 'train' key for both files. The columns are
    # renamed to 'sentence1' / 'label' before shuffling.
    train_data, test_data = (
        get_csv_dataset(csv_file)
        .rename_column(label_column, "label")
        .rename_column("prompt", "sentence1")['train']
        .shuffle(seed=shuffle_seed)
        for csv_file in (train_file, test_file)
    )

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=train_data,
                                 test_data=test_data)
    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=train_data,
                                     test_data=test_data)
    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    # Test step: evaluated with the same checkpoint the style encoder was
    # built from.
    print("-------------------- Evaluation --------------------")
    dual_encoder.test_step(style_checkpoint, train_data, test_data, save_dir)

if __name__ == "__main__":
    # Integer command-line options matching the keyword arguments passed to
    # main() below. The parser is only declared: nothing in this cell calls
    # parse_args(), so these defaults are not what the run uses.
    # (--save_dir, --do_train and --do_evaluation are disabled.)
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--num_iters", default=100, type=int)
    parser.add_argument("--log_step", default=10, type=int)

    # Hyper-parameters actually handed to main(); note that num_iters is 1
    # here, not the parser default of 100.
    training_args = dict(batch_size=32, num_epochs=20, num_iters=1, log_step=10)

    main(**training_args)

# average
# accuracy score: {'accuracy': 0.6865}
# precision score: {'precision': 0.6929100874109159}
# recall score: {'recall': 0.6865}
# f1 score: {'f1': 0.6807407783602909}

# accuracy score: {'accuracy': 0.71}
# precision score: {'precision': 0.7129844761372993}
# recall score: {'recall': 0.71}
# f1 score: {'f1': 0.7040015279480824}

In [None]:
def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    train_file: str = "/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/train_random100_label_1.csv",
    test_file: str = "/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/test_random100_label_1.csv",
    style_checkpoint: str = '/content/drive/MyDrive/msc_project/model/contrastive/contrax/exp_data/diffusiondb100_cls_para_bert-base-cased_coe0.0_temp0.1_unit2_epoch30/diffusiondb100_cls_para_val0.72073_e24.pt',
    content_checkpoint: str = "bert-base-cased",
    save_dir: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/result.json",
    ) -> None:
    """Run the evaluation step of the dual encoder on the CSV splits.

    Loads the train/test CSV files, renames their columns, builds the style
    encoder, the content encoder and the ``DualEncoder`` wrapper, and finally
    calls ``DualEncoder.test_step``.  Only ``test_step`` is invoked from here.

    Args:
        batch_size: Forwarded to ``DualEncoder``.
        num_epochs: Forwarded to ``DualEncoder``.
        log_step: Forwarded to ``DualEncoder``.
        num_iters: Forwarded to ``DualEncoder``.
        train_file: CSV file with ``user_name`` and ``prompt`` columns used as
            the training split.
        test_file: CSV file with the same columns used as the test split.
        style_checkpoint: Checkpoint given to ``StyleEncoder`` and again to
            ``test_step``.  The default is the path this cell used originally.
        content_checkpoint: Checkpoint given to ``ContentEncoder``.
        save_dir: Passed as the last argument of ``test_step``.  Despite the
            name, the default points at a ``result.json`` file.
    """
    torch.backends.cudnn.benchmark = True

    # Prepare data
    train_data = get_csv_dataset(train_file)
    test_data = get_csv_dataset(test_file)

    # Rename to the column names the rest of this notebook reads
    # (e.g. get_random() accesses data['label'] and data['sentence1']).
    for old_name, new_name in (("user_name", "label"), ("prompt", "sentence1")):
        train_data = train_data.rename_column(old_name, new_name)
        test_data = test_data.rename_column(old_name, new_name)

    # Both CSV files are loaded on their own, so each one exposes its rows
    # under the 'train' key.  No seed is passed to shuffle().
    train_data = train_data['train'].shuffle()
    test_data = test_data['train'].shuffle()

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=train_data,
                                 test_data=test_data)

    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=train_data,
                                     test_data=test_data)

    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    # Test Step
    print("-------------------- Evaluation --------------------")
    dual_encoder.test_step(style_checkpoint, train_data, test_data, save_dir)

if __name__ == "__main__":
    # Option definitions only: parse_args() is never called in this cell, so
    # none of the defaults registered below (e.g. num_iters=100) reach main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--num_iters", default=100, type=int)
    parser.add_argument("--log_step", default=10, type=int)

    # Values actually handed to main() for this run.
    training_args = dict(batch_size=32, num_epochs=20, num_iters=1, log_step=10)

    main(**training_args)


In [None]:
def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    train_file: str = "/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/train_random100_label_1.csv",
    test_file: str = "/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/test_random100_label_1.csv",
    style_checkpoint: str = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/diffusiondb100_lcl_coe2_para_bert-base-cased_coe2.0_temp0.1_unit2_epoch30/diffusiondb100_lcl_coe2_para_val0.73512_e24.pt',
    content_checkpoint: str = "bert-base-cased",
    save_dir: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/result.json",
    ) -> None:
    """Run the evaluation step of the dual encoder on the CSV splits.

    Loads the train/test CSV files, renames their columns, builds the style
    encoder, the content encoder and the ``DualEncoder`` wrapper, and finally
    calls ``DualEncoder.test_step``.  Only ``test_step`` is invoked from here.

    Args:
        batch_size: Forwarded to ``DualEncoder``.
        num_epochs: Forwarded to ``DualEncoder``.
        log_step: Forwarded to ``DualEncoder``.
        num_iters: Forwarded to ``DualEncoder``.
        train_file: CSV file with ``user_name`` and ``prompt`` columns used as
            the training split.
        test_file: CSV file with the same columns used as the test split.
        style_checkpoint: Checkpoint given to ``StyleEncoder`` and again to
            ``test_step``.  The default is the path this cell used originally.
        content_checkpoint: Checkpoint given to ``ContentEncoder``.
        save_dir: Passed as the last argument of ``test_step``.  Despite the
            name, the default points at a ``result.json`` file.
    """
    torch.backends.cudnn.benchmark = True

    # Prepare data
    train_data = get_csv_dataset(train_file)
    test_data = get_csv_dataset(test_file)

    # Rename to the column names the rest of this notebook reads
    # (e.g. get_random() accesses data['label'] and data['sentence1']).
    for old_name, new_name in (("user_name", "label"), ("prompt", "sentence1")):
        train_data = train_data.rename_column(old_name, new_name)
        test_data = test_data.rename_column(old_name, new_name)

    # Both CSV files are loaded on their own, so each one exposes its rows
    # under the 'train' key.  No seed is passed to shuffle().
    train_data = train_data['train'].shuffle()
    test_data = test_data['train'].shuffle()

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=train_data,
                                 test_data=test_data)

    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=train_data,
                                     test_data=test_data)

    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    # Test Step
    print("-------------------- Evaluation --------------------")
    dual_encoder.test_step(style_checkpoint, train_data, test_data, save_dir)

if __name__ == "__main__":
    # Option definitions only: parse_args() is never called in this cell, so
    # none of the defaults registered below (e.g. num_iters=100) reach main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--num_iters", default=100, type=int)
    parser.add_argument("--log_step", default=10, type=int)

    # Values actually handed to main() for this run.
    training_args = dict(batch_size=32, num_epochs=20, num_iters=1, log_step=10)

    main(**training_args)


In [None]:
import pandas as pd

def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    train_file: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/train_random100_label_1.csv',
    test_file: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/test_random100_label_1.csv',
    style_checkpoint: str = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/diffusiondb100_lcl_coe1_para_bert-base-cased_coe1.0_temp0.1_unit2_epoch30/diffusiondb100_lcl_coe1_para_val0.73264_e16.pt',
    content_checkpoint: str = "bert-base-cased",
    save_dir: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/result.json",
    ) -> None:
    """Run the evaluation step of the dual encoder on pandas DataFrames.

    Reads the train/test CSV files with pandas, keeps only the ``prompt`` and
    ``user_name`` columns (renamed to ``content`` and ``Target``), builds the
    style encoder, the content encoder and the ``DualEncoder`` wrapper, and
    finally calls ``DualEncoder.test_step``.

    The validation CSV is not read: nothing in this function consumed it.

    Args:
        batch_size: Forwarded to ``DualEncoder``.
        num_epochs: Forwarded to ``DualEncoder``.
        log_step: Forwarded to ``DualEncoder``.
        num_iters: Forwarded to ``DualEncoder``.
        train_file: CSV file with ``prompt`` and ``user_name`` columns used as
            the training split.
        test_file: CSV file with the same columns used as the test split.
        style_checkpoint: Checkpoint given to ``StyleEncoder`` and again to
            ``test_step``.  The default is the path this cell used originally.
        content_checkpoint: Checkpoint given to ``ContentEncoder``.
        save_dir: Passed as the last argument of ``test_step``.  Despite the
            name, the default points at a ``result.json`` file.
    """
    torch.backends.cudnn.benchmark = True

    # Prepare data: keep the two columns of interest under the names
    # 'content' (prompt text) and 'Target' (user name).
    column_map = {'prompt': 'content', 'user_name': 'Target'}
    source_columns = list(column_map)
    nlp_train = pd.read_csv(train_file)[source_columns].rename(columns=column_map)
    nlp_test = pd.read_csv(test_file)[source_columns].rename(columns=column_map)

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=nlp_train,
                                 test_data=nlp_test)

    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=nlp_train,
                                     test_data=nlp_test)

    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    # Test Step
    print("-------------------- Evaluation --------------------")
    dual_encoder.test_step(style_checkpoint, nlp_train, nlp_test, save_dir)


if __name__ == "__main__":
    # Option definitions only: parse_args() is never called in this cell, so
    # none of the defaults registered below (e.g. num_iters=100) reach main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--num_iters", default=100, type=int)
    parser.add_argument("--log_step", default=10, type=int)

    # Values actually handed to main() for this run.
    training_args = dict(batch_size=32, num_epochs=20, num_iters=1, log_step=10)

    main(**training_args)


# current

In [None]:
import pandas as pd
import torch


def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    train_file: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/train_random100_label_1.csv',
    test_file: str = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased/test_random100_label_1.csv',
    style_checkpoint: str = "/content/drive/MyDrive/msc_project/model/contrastive/club_0.73/style_encoder_supcon_18.pt",
    content_checkpoint: str = "bert-base-cased",
    save_dir: str = "/content/drive/MyDrive/msc_project/model/contrastive/club/result.json",
    ) -> None:
    """Run the evaluation step of the dual encoder on pandas DataFrames.

    Reads the train/test CSV files with pandas, keeps only the ``prompt`` and
    ``user_name`` columns (renamed to ``content`` and ``Target``), builds the
    style encoder, the content encoder and the ``DualEncoder`` wrapper, and
    finally calls ``DualEncoder.test_step``.

    The validation CSV is not read: nothing in this function consumed it.
    The same ``style_checkpoint`` is used for ``StyleEncoder`` and for
    ``test_step``, so it is assigned once instead of twice.

    Args:
        batch_size: Forwarded to ``DualEncoder``.
        num_epochs: Forwarded to ``DualEncoder``.
        log_step: Forwarded to ``DualEncoder``.
        num_iters: Forwarded to ``DualEncoder``.
        train_file: CSV file with ``prompt`` and ``user_name`` columns used as
            the training split.
        test_file: CSV file with the same columns used as the test split.
        style_checkpoint: Checkpoint given to ``StyleEncoder`` and again to
            ``test_step``.  The default is the path this cell used originally.
        content_checkpoint: Checkpoint given to ``ContentEncoder``.
        save_dir: Passed as the last argument of ``test_step``.  Despite the
            name, the default points at a ``result.json`` file.
    """
    # cudnn.benchmark is left untouched in this cell (the line enabling it was
    # commented out); cached CUDA memory is released before building models.
    torch.cuda.empty_cache()

    # Prepare data: keep the two columns of interest under the names
    # 'content' (prompt text) and 'Target' (user name).
    column_map = {'prompt': 'content', 'user_name': 'Target'}
    source_columns = list(column_map)
    nlp_train = pd.read_csv(train_file)[source_columns].rename(columns=column_map)
    nlp_test = pd.read_csv(test_file)[source_columns].rename(columns=column_map)

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=nlp_train,
                                 test_data=nlp_test)

    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=nlp_train,
                                     test_data=nlp_test)

    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    # Test Step
    print("-------------------- Evaluation --------------------")
    dual_encoder.test_step(style_checkpoint, nlp_train, nlp_test, save_dir)

if __name__ == "__main__":
    # Option definitions only: parse_args() is never called in this cell, so
    # none of the defaults registered below (e.g. num_iters=100) reach main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--num_epochs", default=20, type=int)
    parser.add_argument("--num_iters", default=100, type=int)
    parser.add_argument("--log_step", default=10, type=int)

    # Values actually handed to main() for this run.
    training_args = dict(batch_size=32, num_epochs=20, num_iters=1, log_step=10)

    main(**training_args)


In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score

# Example lists
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 
27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 51, 52, 52, 52, 52, 52, 52, 52, 
52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 52, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 57, 57, 57, 57, 57, 57, 57, 57, 57, 57, 57, 57, 57, 57, 57, 57, 57, 57, 57, 57, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 68, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 74, 75, 75, 75, 75, 75, 75, 75, 75, 75, 75, 75, 75, 75, 75, 75, 75, 75, 75, 75, 75, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 77, 77, 77, 77, 77, 77, 77, 
77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 82, 82, 82, 82, 82, 82, 82, 82, 82, 82, 82, 82, 82, 82, 82, 82, 82, 82, 82, 82, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 91, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 93, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 97, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99]


y_pred = [0, 23, 81, 0, 78, 18, 0, 0, 69, 18, 0, 0, 36, 36, 0, 41, 52, 82, 0, 38, 1, 1, 1, 1, 84, 39, 1, 1, 29, 18, 1, 81, 1, 1, 1, 6, 1, 33, 1, 42, 2, 2, 2, 48, 2, 70, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 93, 2, 4, 2, 3, 3, 3, 3, 3, 3, 3, 39, 33, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 68, 4, 94, 4, 4, 4, 4, 11, 4, 4, 4, 4, 41, 28, 53, 4, 4, 4, 4, 0, 52, 5, 5, 50, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 21, 72, 6, 6, 39, 6, 7, 6, 72, 75, 6, 6, 6, 6, 6, 72, 6, 6, 6, 6, 7, 7, 49, 74, 29, 23, 68, 7, 7, 74, 7, 7, 32, 7, 7, 94, 7, 7, 51, 7, 39, 8, 81, 29, 8, 8, 8, 8, 8, 47, 29, 68, 8, 7, 8, 81, 8, 73, 8, 57, 9, 9, 9, 9, 9, 9, 9, 82, 9, 9, 9, 9, 9, 9, 9, 58, 9, 9, 9, 29, 52, 85, 85, 39, 10, 10, 10, 10, 10, 43, 35, 29, 53, 26, 10, 90, 10, 10, 48, 46, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 53, 12, 12, 12, 49, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 80, 13, 13, 13, 13, 49, 13, 13, 13, 13, 84, 13, 26, 13, 29, 13, 13, 53, 13, 14, 14, 16, 65, 30, 47, 26, 14, 14, 31, 82, 37, 8, 15, 35, 43, 84, 14, 14, 14, 15, 78, 15, 35, 15, 15, 15, 15, 15, 15, 41, 15, 68, 15, 15, 15, 15, 15, 15, 68, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 17, 48, 48, 17, 17, 17, 51, 31, 17, 17, 48, 17, 17, 17, 46, 68, 58, 41, 17, 17, 4, 18, 36, 86, 25, 18, 18, 18, 81, 18, 18, 74, 18, 18, 18, 18, 25, 18, 64, 47, 19, 86, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 34, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 73, 20, 21, 34, 98, 85, 29, 21, 21, 21, 21, 21, 21, 73, 21, 21, 21, 21, 34, 21, 21, 94, 47, 22, 22, 22, 22, 22, 22, 22, 22, 22, 33, 22, 22, 22, 22, 22, 22, 60, 22, 22, 14, 23, 23, 77, 23, 52, 67, 68, 23, 23, 13, 70, 23, 68, 23, 61, 23, 23, 52, 23, 24, 13, 24, 24, 24, 10, 24, 82, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 6, 25, 25, 7, 99, 61, 90, 25, 47, 25, 32, 25, 25, 25, 62, 25, 26, 26, 26, 26, 3, 26, 26, 26, 26, 81, 26, 26, 53, 26, 
26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 32, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 86, 28, 28, 28, 28, 10, 45, 6, 86, 29, 29, 10, 29, 29, 72, 33, 29, 47, 29, 29, 29, 53, 65, 29, 45, 30, 0, 30, 68, 45, 18, 30, 30, 30, 30, 41, 62, 30, 78, 30, 30, 30, 39, 0, 33, 31, 31, 31, 31, 31, 31, 84, 31, 31, 31, 10, 31, 31, 31, 31, 31, 31, 31, 18, 31, 52, 41, 32, 32, 7, 21, 32, 32, 32, 44, 32, 32, 32, 32, 94, 32, 32, 32, 32, 1, 33, 33, 33, 33, 33, 78, 33, 67, 33, 33, 67, 33, 41, 33, 47, 78, 33, 67, 29, 33, 34, 34, 77, 86, 18, 34, 34, 34, 34, 34, 38, 34, 18, 34, 46, 14, 34, 68, 34, 51, 18, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 33, 14, 35, 35, 35, 35, 35, 73, 36, 36, 36, 36, 36, 36, 36, 36, 8, 6, 36, 36, 70, 36, 36, 36, 36, 80, 36, 37, 37, 37, 37, 37, 81, 70, 37, 37, 34, 37, 52, 37, 73, 37, 37, 37, 37, 37, 52, 91, 38, 38, 38, 38, 38, 47, 38, 78, 38, 18, 38, 1, 38, 38, 38, 38, 38, 29, 0, 3, 39, 15, 7, 3, 39, 39, 29, 39, 39, 39, 4, 94, 10, 39, 21, 39, 39, 45, 45, 40, 40, 40, 89, 40, 32, 84, 38, 40, 40, 40, 40, 40, 93, 29, 40, 40, 40, 40, 40, 6, 41, 75, 1, 41, 6, 41, 26, 54, 41, 75, 41, 86, 41, 58, 41, 41, 0, 41, 38, 42, 42, 42, 42, 42, 42, 42, 42, 26, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 96, 43, 51, 73, 43, 43, 46, 43, 43, 58, 43, 46, 43, 43, 43, 43, 85, 43, 43, 14, 73, 44, 44, 44, 44, 44, 73, 44, 44, 73, 44, 74, 73, 82, 44, 26, 32, 21, 44, 47, 62, 94, 21, 73, 40, 45, 21, 3, 39, 45, 85, 45, 21, 73, 45, 45, 45, 23, 45, 73, 46, 4, 61, 46, 46, 84, 46, 46, 8, 46, 19, 31, 23, 46, 53, 10, 82, 46, 46, 46, 17, 30, 26, 47, 44, 47, 47, 56, 67, 47, 47, 47, 47, 47, 69, 84, 39, 47, 66, 48, 94, 48, 48, 48, 10, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 10, 49, 49, 49, 49, 49, 49, 49, 49, 15, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 51, 51, 51, 61, 51, 51, 51, 51, 67, 51, 51, 51, 32, 51, 51, 85, 51, 51, 51, 
51, 23, 52, 52, 5, 13, 52, 52, 52, 48, 52, 52, 84, 52, 52, 52, 49, 52, 52, 53, 93, 53, 53, 5, 53, 50, 53, 53, 61, 53, 90, 53, 33, 53, 53, 53, 53, 90, 35, 53, 94, 54, 67, 54, 35, 1, 65, 54, 54, 54, 7, 54, 54, 68, 32, 54, 97, 54, 54, 54, 32, 55, 55, 49, 55, 55, 55, 55, 55, 55, 55, 55, 55, 94, 55, 29, 55, 55, 55, 94, 55, 56, 42, 46, 56, 56, 56, 56, 56, 56, 56, 73, 56, 47, 56, 56, 56, 56, 56, 56, 56, 57, 57, 57, 57, 57, 57, 79, 57, 57, 57, 57, 57, 57, 57, 57, 57, 57, 57, 31, 57, 62, 39, 58, 39, 65, 58, 58, 58, 58, 65, 58, 86, 58, 58, 58, 58, 98, 58, 58, 58, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 69, 59, 59, 59, 59, 60, 60, 60, 94, 60, 60, 60, 34, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 61, 61, 61, 44, 61, 61, 61, 85, 61, 61, 16, 64, 61, 61, 7, 73, 61, 74, 64, 70, 62, 62, 39, 81, 62, 62, 62, 62, 85, 20, 31, 62, 62, 62, 62, 62, 1, 85, 62, 62, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 63, 64, 64, 64, 85, 64, 64, 7, 64, 64, 66, 3, 64, 64, 64, 41, 64, 64, 64, 64, 64, 65, 65, 86, 65, 41, 65, 65, 65, 65, 65, 65, 65, 65, 65, 95, 81, 65, 65, 65, 73, 66, 66, 66, 66, 66, 13, 66, 66, 33, 66, 66, 66, 66, 66, 66, 79, 66, 66, 66, 66, 74, 74, 67, 67, 67, 7, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 74, 67, 98, 23, 94, 75, 86, 68, 68, 49, 46, 91, 48, 68, 48, 81, 48, 58, 68, 68, 73, 65, 69, 69, 69, 78, 69, 21, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 3, 69, 69, 3, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 70, 31, 55, 70, 70, 57, 9, 70, 70, 13, 14, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 93, 71, 71, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 78, 72, 73, 38, 73, 73, 73, 13, 73, 83, 94, 73, 73, 73, 73, 47, 73, 73, 73, 21, 21, 73, 74, 74, 74, 74, 74, 74, 74, 61, 33, 74, 74, 74, 74, 74, 74, 74, 74, 25, 74, 74, 62, 75, 75, 75, 75, 75, 75, 67, 75, 13, 75, 70, 75, 61, 51, 93, 71, 75, 61, 10, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 76, 17, 76, 76, 76, 76, 76, 77, 77, 
77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 77, 78, 48, 78, 78, 78, 78, 78, 82, 78, 82, 78, 78, 78, 78, 61, 72, 78, 78, 78, 31, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 50, 80, 80, 80, 80, 80, 80, 80, 81, 81, 57, 81, 57, 51, 37, 81, 81, 81, 78, 81, 81, 31, 81, 81, 81, 81, 81, 98, 29, 55, 82, 82, 82, 78, 37, 54, 7, 82, 82, 46, 57, 82, 10, 82, 61, 40, 82, 82, 73, 83, 29, 83, 83, 24, 83, 83, 83, 73, 83, 83, 83, 29, 83, 85, 83, 83, 83, 94, 68, 61, 84, 84, 33, 84, 84, 84, 84, 84, 84, 16, 84, 84, 84, 51, 32, 84, 61, 5, 85, 7, 73, 45, 7, 85, 85, 86, 85, 85, 85, 85, 85, 3, 85, 69, 85, 29, 81, 85, 86, 86, 21, 86, 42, 86, 86, 86, 51, 8, 86, 86, 86, 86, 59, 86, 58, 86, 8, 42, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 87, 88, 88, 88, 88, 47, 88, 83, 88, 88, 88, 88, 88, 88, 88, 88, 88, 74, 88, 88, 88, 89, 26, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 89, 97, 89, 89, 89, 89, 90, 90, 90, 90, 90, 90, 10, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 90, 91, 91, 91, 91, 91, 26, 91, 91, 91, 91, 18, 91, 91, 91, 91, 91, 91, 91, 91, 91, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 41, 92, 92, 92, 79, 92, 93, 93, 52, 93, 93, 93, 93, 93, 93, 51, 93, 93, 93, 93, 93, 93, 93, 94, 88, 93, 18, 21, 34, 85, 94, 94, 41, 90, 94, 94, 94, 79, 94, 94, 98, 94, 94, 10, 94, 94, 95, 67, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 93, 81, 95, 96, 1, 94, 96, 58, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 61, 96, 94, 96, 94, 97, 97, 97, 97, 49, 97, 97, 97, 97, 97, 97, 97, 97, 97, 90, 97, 97, 97, 97, 97, 73, 98, 98, 98, 98, 98, 56, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 58, 99, 99, 99, 99, 44, 99, 99, 99, 99, 99, 99, 44, 99, 99, 99, 99, 99, 99, 99, 99]



# Macro-averaged precision, recall and F1 over the label lists defined above;
# the scorers are evaluated in that order and bound to the usual three names.
precision, recall, f1 = (
    scorer(y_true, y_pred, average='macro')
    for scorer in (precision_score, recall_score, f1_score)
)

# Report each score with four decimals.
print('Precision: {:.4f}'.format(precision))
print('Recall: {:.4f}'.format(recall))
print('F1 Score: {:.4f}'.format(f1))


In [None]:
import pandas as pd
import torch


def _load_prompt_frame(csv_path):
    """Read one split CSV and return its ``prompt`` / ``user_name`` columns renamed to ``content`` / ``Target``."""
    frame = pd.read_csv(csv_path)[['prompt', 'user_name']]
    frame.columns = ['content', 'Target']
    return frame


def main(
    batch_size: int,
    num_epochs: int,
    log_step: int,
    num_iters: int,
    ):
    """Build the style/content dual encoder and run its evaluation step.

    Args:
        batch_size: forwarded to ``DualEncoder``.
        num_epochs: forwarded to ``DualEncoder``.
        log_step: forwarded to ``DualEncoder``.
        num_iters: forwarded to ``DualEncoder``.

    Only the train and test splits are loaded: nothing in this function reads
    a validation split, so that CSV is no longer fetched from Drive.
    """
    torch.cuda.empty_cache()

    # Prepare data: paraphrased DiffusionDB prompts, one CSV per split.
    data_dir = '/content/drive/MyDrive/msc_project/data/diffusiondb/paraphrased'
    nlp_train = _load_prompt_frame(f'{data_dir}/train_random100_label_1.csv')
    nlp_test = _load_prompt_frame(f'{data_dir}/test_random100_label_1.csv')

    # Checkpoints: a fine-tuned style encoder and a stock BERT for content.
    style_checkpoint = '/content/drive/MyDrive/msc_project/model/contrastive/lcl/diffusiondb100_lcl_coe1_para_bert-base-cased_coe1.0_temp0.1_unit2_epoch30/diffusiondb100_lcl_coe1_para_val0.73264_e16.pt'
    content_checkpoint = "bert-base-cased"

    # Build model
    style_encoder = StyleEncoder(checkpoint=style_checkpoint,
                                 train_data=nlp_train,
                                 test_data=nlp_test)

    content_encoder = ContentEncoder(checkpoint=content_checkpoint,
                                     train_data=nlp_train,
                                     test_data=nlp_test)

    dual_encoder = DualEncoder(content_encoder=content_encoder,
                               style_encoder=style_encoder,
                               num_epochs=num_epochs,
                               batch_size=batch_size,
                               num_iters=num_iters,
                               log_step=log_step)

    # Test step. Note that save_dir is the path of a JSON file, not a directory.
    print("-------------------- Evaluation --------------------")
    save_dir = "/content/drive/MyDrive/msc_project/model/contrastive/club/result.json"
    dual_encoder.test_step(style_checkpoint, nlp_train, nlp_test, save_dir)

if __name__ == "__main__":
    # Hyper-parameters are fixed here instead of being parsed from the command
    # line: no argument parsing happens in this cell, so no parser is built.
    training_args = {
      'batch_size': 32,
      'num_epochs': 20,
      'num_iters': 1,
      'log_step': 10,
    }

    main(**training_args)


# test1: CLUB mutual-information estimators on a toy Gaussian sampler

In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

import math

class CLUBForCategorical(nn.Module): # Update 04/27/2022
    '''
    CLUB estimator of a mutual-information upper bound I(X, Y) between
    vector-like embeddings X (continuous) and categorical labels Y (discrete).

    The variational network models q(y | x) as a softmax over ``label_num``
    classes; it is trained with ``learning_loss`` and queried with ``forward``.
    '''
    def __init__(self, input_dim, label_num, hidden_size=None):
        '''
        input_dim   : the dimension of input embeddings
        label_num   : the number of categorical labels
        hidden_size : width of an optional hidden ReLU layer; when None the
                      variational network is a single linear layer
        '''
        super().__init__()

        if hidden_size is None:
            self.variational_net = nn.Linear(input_dim, label_num)
        else:
            self.variational_net = nn.Sequential(
                nn.Linear(input_dim, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, label_num)
            )

    def forward(self, inputs, labels):
        '''
        inputs : shape [batch_size, input_dim], a batch of embeddings
        labels : shape [batch_size], a batch of label indices (integer tensor
                 with values in [0, label_num))

        Returns a scalar tensor: mean log q(y_i | x_i) over matched pairs minus
        mean log q(y_j | x_i) over all pairs.
        '''
        logits = self.variational_net(inputs)                  # [sample_size, label_num]
        log_probs = nn.functional.log_softmax(logits, dim=-1)  # [sample_size, label_num]

        # log_mat[i, j] = log q(y_j | x_i). Selecting columns by label gives the
        # full pairwise matrix without expanding to [sample_size, sample_size, label_num].
        log_mat = log_probs[:, labels]                          # [sample_size, sample_size]

        positive = torch.diag(log_mat).mean()   # matched (x_i, y_i) pairs
        negative = log_mat.mean()               # every (x_i, y_j) pair
        return positive - negative

    def loglikeli(self, inputs, labels):
        '''Mean log-likelihood of ``labels`` under q(y | x) (negative cross-entropy).'''
        logits = self.variational_net(inputs)
        return - nn.functional.cross_entropy(logits, labels)

    def learning_loss(self, inputs, labels):
        '''Loss for fitting the variational network: the negative log-likelihood.'''
        return - self.loglikeli(inputs, labels)


class CLUB(nn.Module):  # CLUB: Mutual Information Contrastive Learning Upper Bound
    '''
        CLUB estimate of I(X, Y) using a Gaussian approximation q(Y|X).
        Methods:
            forward()   : the upper-bound estimate for a batch of paired samples
            loglikeli() : the (unnormalized) log-likelihood of q(Y|X) on the batch
        Arguments:
            x_dim, y_dim         : dimensions of the samples from X and Y
            hidden_size          : both sub-networks use hidden_size // 2 hidden units
            x_samples, y_samples : tensors of shape [sample_size, x_dim / y_dim]
    '''
    def __init__(self, x_dim, y_dim, hidden_size):
        super(CLUB, self).__init__()
        half_hidden = hidden_size // 2
        # Mean of q(Y|X).
        self.p_mu = nn.Sequential(
            nn.Linear(x_dim, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, y_dim),
        )
        # Log-variance of q(Y|X); the final Tanh bounds it to (-1, 1).
        self.p_logvar = nn.Sequential(
            nn.Linear(x_dim, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, y_dim),
            nn.Tanh(),
        )

    def get_mu_logvar(self, x_samples):
        '''Return (mean, log-variance) of q(Y|X) for every row of x_samples.'''
        return self.p_mu(x_samples), self.p_logvar(x_samples)

    def forward(self, x_samples, y_samples):
        '''Scalar CLUB estimate: matched-pair term minus the all-pairs term.'''
        mu, logvar = self.get_mu_logvar(x_samples)
        variance = logvar.exp()

        # Matched pairs (x_i, y_i).
        positive = - (mu - y_samples)**2 /2./variance

        # All pairs (x_i, y_j): broadcast [nsample,1,dim] against [1,nsample,dim],
        # then average the squared error over j.
        pairwise_sq = (y_samples.unsqueeze(0) - mu.unsqueeze(1))**2
        negative = - pairwise_sq.mean(dim=1)/2./variance

        gap = positive.sum(dim = -1) - negative.sum(dim = -1)
        return gap.mean()

    def loglikeli(self, x_samples, y_samples): # unnormalized loglikelihood
        '''Batch mean of the per-sample Gaussian log-likelihood (constants dropped).'''
        mu, logvar = self.get_mu_logvar(x_samples)
        per_dim = -(mu - y_samples)**2 /logvar.exp()-logvar
        return per_dim.sum(dim=1).mean(dim=0)

    def learning_loss(self, x_samples, y_samples):
        '''Loss for fitting q(Y|X): the negative log-likelihood.'''
        return - self.loglikeli(x_samples, y_samples)


class CLUBMean(nn.Module):  # Set variance of q(y|x) to 1, logvar = 0. Update 11/26/2022
    '''
    CLUB estimator whose approximation q(Y|X) has a learned mean and a fixed
    unit variance, so only the mean network ``p_mu`` is trained.
    '''
    def __init__(self, x_dim, y_dim, hidden_size=None):
        super(CLUBMean, self).__init__()

        # Mean of q(Y|X): linear when no hidden size is given, otherwise a
        # one-hidden-layer MLP (hidden_size is cast to int).
        if hidden_size is not None:
            width = int(hidden_size)
            self.p_mu = nn.Sequential(
                nn.Linear(x_dim, width),
                nn.ReLU(),
                nn.Linear(width, y_dim),
            )
        else:
            self.p_mu = nn.Linear(x_dim, y_dim)

    def get_mu_logvar(self, x_samples):
        '''Return (mean, 0): the log-variance is the constant 0 (unit variance).'''
        return self.p_mu(x_samples), 0

    def forward(self, x_samples, y_samples):
        '''Scalar CLUB estimate: matched-pair term minus the all-pairs term.'''
        mu, _ = self.get_mu_logvar(x_samples)

        # Matched pairs (x_i, y_i).
        positive = - (mu - y_samples)**2 /2.

        # All pairs (x_i, y_j): broadcast [nsample,1,dim] against [1,nsample,dim],
        # then average the squared error over j.
        pairwise_sq = (y_samples.unsqueeze(0) - mu.unsqueeze(1))**2
        negative = - pairwise_sq.mean(dim=1)/2.

        gap = positive.sum(dim = -1) - negative.sum(dim = -1)
        return gap.mean()

    def loglikeli(self, x_samples, y_samples): # unnormalized loglikelihood
        '''Batch mean of the negative squared error summed over dimensions.'''
        mu, _ = self.get_mu_logvar(x_samples)
        neg_sq_err = -(mu - y_samples)**2
        return neg_sq_err.sum(dim=1).mean(dim=0)

    def learning_loss(self, x_samples, y_samples):
        '''Loss for fitting the mean network: the negative log-likelihood.'''
        return - self.loglikeli(x_samples, y_samples)





class CLUBSample(nn.Module):  # Sampled version of the CLUB estimator
    '''
    CLUB estimator that draws one shuffled negative y per x instead of
    averaging over every (x_i, y_j) pair.
    '''
    def __init__(self, x_dim, y_dim, hidden_size):
        super(CLUBSample, self).__init__()
        half_hidden = hidden_size // 2
        # Mean of q(Y|X).
        self.p_mu = nn.Sequential(
            nn.Linear(x_dim, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, y_dim),
        )
        # Log-variance of q(Y|X); the final Tanh bounds it to (-1, 1).
        self.p_logvar = nn.Sequential(
            nn.Linear(x_dim, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, y_dim),
            nn.Tanh(),
        )

    def get_mu_logvar(self, x_samples):
        '''Return (mean, log-variance) of q(Y|X) for every row of x_samples.'''
        return self.p_mu(x_samples), self.p_logvar(x_samples)

    def loglikeli(self, x_samples, y_samples):
        '''Batch mean of the per-sample Gaussian log-likelihood (constants dropped).'''
        mu, logvar = self.get_mu_logvar(x_samples)
        per_dim = -(mu - y_samples)**2 /logvar.exp()-logvar
        return per_dim.sum(dim=1).mean(dim=0)

    def forward(self, x_samples, y_samples):
        '''Scalar estimate: matched-pair term minus a randomly permuted-pair term, halved.'''
        mu, logvar = self.get_mu_logvar(x_samples)

        # One random permutation of the batch supplies the negative partner for each x.
        n_samples = x_samples.shape[0]
        shuffled = torch.randperm(n_samples).long()

        variance = logvar.exp()
        positive = - (mu - y_samples)**2 / variance
        negative = - (mu - y_samples[shuffled])**2 / variance

        gap = positive.sum(dim = -1) - negative.sum(dim = -1)
        return gap.mean()/2.

    def learning_loss(self, x_samples, y_samples):
        '''Loss for fitting q(Y|X): the negative log-likelihood.'''
        return - self.loglikeli(x_samples, y_samples)



In [None]:
class GaussianSampler(nn.Module):
    '''
    Learnable sampler of ``dim`` independent pairs (x, y) of correlated
    Gaussian variables, one angle parameter per pair.

    All computation follows the device of the parameter ``p_theta_``: move the
    module with ``.cuda()`` / ``.to(device)`` to sample on a GPU. It also runs
    on a CPU-only machine.
    '''
    def __init__(self, dim, para_list = None):
        '''
        dim       : number of (x, y) pairs generated per sample
        para_list : initial angle per pair; defaults to 0.55 for every pair
        '''
        super(GaussianSampler, self).__init__()
        self.dim = dim
        if para_list is None:
            para_list = [0.55] * dim
        self.p_theta_ = torch.nn.Parameter(torch.tensor(para_list, requires_grad = True))

    def get_trans_mat(self):
        '''Return the [dim, 2, 2] mixing matrices with columns (sin, cos) and (cos, sin).'''
        p_theta = self.p_theta_.unsqueeze(-1)                  # [dim, 1]
        sin_t, cos_t = torch.sin(p_theta), torch.cos(p_theta)

        trans_row1 = torch.cat((sin_t, cos_t), dim=-1).unsqueeze(-1)   # [dim, 2, 1]
        trans_row2 = torch.cat((cos_t, sin_t), dim=-1).unsqueeze(-1)   # [dim, 2, 1]
        return torch.cat((trans_row1, trans_row2), dim=-1)             # [dim, 2, 2]

    def gen_samples(self, num_sample, cuda = True):
        '''
        Draw ``num_sample`` samples and return (x, y), each of shape [num_sample, dim].

        cuda : when True the samples are returned as tensors on the module's
               device (still attached to the graph); when False they are
               detached and returned as NumPy arrays.
        '''
        # Noise is drawn with the CPU generator and then moved to the
        # parameter's device.
        noise = torch.randn(self.dim, num_sample, 2).to(self.p_theta_.device)
        trans_mat = self.get_trans_mat()
        samples = torch.bmm(noise, trans_mat).transpose(0,1)   # [nsample, dim, 2]
        if not cuda:
            samples = samples.cpu().detach().numpy()
        return samples[:,:,0], samples[:,:,1]

    def get_covariance(self):
        '''Return 2*sin(theta)*cos(theta) for every pair, shape [dim].'''
        p_theta = self.p_theta_
        return (2.*torch.sin(p_theta)*torch.cos(p_theta))

    def get_MI(self):
        '''Return -1/2 * sum(log(1 - rho**2)) over all pairs as a Python float.'''
        rho = self.get_covariance()
        return -1./2.*torch.log(1-rho**2).sum().item()

In [None]:
# Toy experiment: alternately (1) update the Gaussian sampler to minimise the
# MI estimate and (2) fit the estimator's variational network to the sampler,
# then compare the estimate with the sampler's closed-form MI value.
lr = 1e-4
batch_size = 100
num_iter = 5000
sample_dim = 2
hidden_size = 5
estimator_name = "CLUB"

# Explicit name -> class table instead of eval() on a string; these are the
# estimators built as (x_dim, y_dim, hidden_size).
ESTIMATORS = {"CLUB": CLUB, "CLUBMean": CLUBMean, "CLUBSample": CLUBSample}


sampler = GaussianSampler(sample_dim).cuda()
x_sample, y_sample = sampler.gen_samples(1000, cuda = False)
plt.scatter(x_sample, y_sample)
plt.title("Samples before training")
plt.show()

mi_estimator = ESTIMATORS[estimator_name](sample_dim, sample_dim, hidden_size).cuda()

sampler_optimizer = torch.optim.Adam(sampler.parameters(), lr = lr)
mi_optimizer = torch.optim.Adam(mi_estimator.parameters(), lr = lr)

mi_true_values = []
mi_est_values = []
mi_loss_all = []
sampler_loss_all = []
min_mi = 100.

for i in range(num_iter):
    # Step 1: one sampler update that lowers the current MI estimate.
    sampler.train()
    mi_estimator.eval()
    x_samples, y_samples = sampler.gen_samples(batch_size)
    sampler_loss = mi_estimator(x_samples, y_samples)
    sampler_optimizer.zero_grad()
    sampler_loss.backward()
    sampler_optimizer.step()

    # Step 2: five estimator updates on fresh batches from the new sampler.
    mi_losses = 0
    for j in range(5):
        mi_estimator.train()
        x_samples, y_samples = sampler.gen_samples(batch_size)
        mi_loss = mi_estimator.learning_loss(x_samples, y_samples)
        mi_optimizer.zero_grad()
        mi_loss.backward()
        mi_optimizer.step()
        mi_losses += mi_loss.item()

    # Logging: evaluate the estimate once (no graph needed) and reuse it for
    # both the history and the progress line.
    true_mi = sampler.get_MI()
    with torch.no_grad():
        est_mi = mi_estimator(x_samples, y_samples).item()

    mi_true_values.append(true_mi)
    mi_est_values.append(est_mi)
    mi_loss_all.append(mi_losses / 5)
    sampler_loss_all.append(sampler_loss.item())
    if i % 100 ==0:
        print("step {}, true MI value {}, estimated MI value {}, MI loss {}, sampler loss {}".format(i, true_mi, est_mi, mi_losses / 5, sampler_loss.item()))


steps = np.arange(len(mi_est_values))

plt.plot(steps, mi_est_values, label=estimator_name + " est")
plt.plot(np.arange(len(mi_true_values)), mi_true_values, label="True MI value")
plt.xlabel("iteration")
plt.ylabel("mutual information")
plt.legend()
plt.show()

plt.plot(steps, sampler_loss_all)
plt.xlabel("iteration")
plt.ylabel("sampler loss")
plt.show()

plt.plot(steps, mi_loss_all)
plt.xlabel("iteration")
plt.ylabel("estimator learning loss")
plt.show()

x_sample, y_sample = sampler.gen_samples(1000, cuda=False)
plt.scatter(x_sample, y_sample)
plt.title("Samples after training")
plt.show()

# test: evaluate a trained style-encoder checkpoint on the test split

In [None]:
class LogisticRegression(nn.Module):
    """Classification head: Dropout -> Linear -> LeakyReLU(0.2) -> Dropout -> Linear.

    The layers live in ``self.nn`` in that order, so the two linear layers are
    ``nn.1`` and ``nn.4`` in the state dict.
    """

    def __init__(self, in_dim, hid_dim, out_dim, dropout=0):
        super().__init__()
        print(f'Logistic Regression classifier of dim ({in_dim} {hid_dim} {out_dim})')

        layers = [
            nn.Dropout(p=dropout),
            nn.Linear(in_dim, hid_dim, bias=True),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hid_dim, out_dim, bias=True),
        ]
        self.nn = nn.Sequential(*layers)

    def forward(self, x, return_feat=False):
        """Return the logits, or ``(logits, x)`` when ``return_feat`` is true."""
        logits = self.nn(x)
        return (logits, x) if return_feat else logits


class BertClassifier(nn.Module):
    """Encoder plus classification head applied to the flattened token states."""

    FEAT_LEN = 768

    def __init__(self, raw_bert, classifier):
        super().__init__()
        self.bert = raw_bert
        self.fc = classifier

    def forward(self, x, return_feat=False):
        """Classify a tokenized batch.

        ``x`` is indexed positionally: ``x[0]`` is passed as ``input_ids`` and
        ``x[2]`` as ``attention_mask``; ``x[1]`` is not used.

        Returns the head output, or ``(output, last_hidden_state, encoder_output)``
        when ``return_feat`` is true.
        """
        input_ids, attention_mask = x[0], x[2]
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # The head sees every token state, flattened from dim 1 onwards
        # (not the pooled output).
        token_states = encoded.last_hidden_state
        logits = self.fc(token_states.flatten(1))

        if not return_feat:
            return logits
        return logits, token_states, encoded


def load_model_dic(model, ckpt_path, verbose=True, strict=True):
    """
    Load a saved state dict into ``model``, tolerating DataParallel checkpoints.

    The checkpoint is read once onto the CPU (``load_state_dict`` then copies
    the values into the model's own tensors), so a checkpoint written on a GPU
    can also be loaded on a CPU-only machine. If the direct load fails and the
    checkpoint keys carry the ``module.`` prefix added by ``DataParallel``,
    that leading prefix is stripped and the load is retried.

    Args:
        model: module to receive the weights.
        ckpt_path: path of a file written with ``torch.save(state_dict, ...)``.
        verbose: print a confirmation line after loading.
        strict: forwarded to ``load_state_dict``.

    Returns:
        The same ``model`` object, with weights loaded.

    Raises:
        AssertionError: if ``ckpt_path`` does not exist.
        RuntimeError: if the state dict does not fit the model.
    """
    assert os.path.exists(ckpt_path), f"trained model {ckpt_path} does not exist"

    state_dict = torch.load(ckpt_path, map_location="cpu")
    try:
        model.load_state_dict(state_dict, strict=strict)
    except RuntimeError:
        prefix = 'module.'
        if not any(key.startswith(prefix) for key in state_dict):
            # Not a DataParallel checkpoint: surface the original mismatch.
            raise
        # Strip only a leading prefix and leave other keys untouched.
        stripped = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(stripped, strict=strict)
    if verbose:
        print(f'Model loaded: {ckpt_path}')

    return model


def save_model(ckpt_dir, cp_name, model):
    """
    Save ``model``'s state dict to ``ckpt_dir/cp_name``, creating the directory if needed.

    A ``DataParallel`` wrapper is unwrapped first, so the saved keys do not
    carry the ``module.`` prefix.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    target_path = os.path.join(ckpt_dir, cp_name)
    unwrapped = model.module if isinstance(model, torch.nn.DataParallel) else model
    torch.save(unwrapped.state_dict(), target_path)
    print(f'Model saved: {target_path}')


class StyleEncoder():
    """Wraps a fine-tuned ``BertClassifier`` and exposes its token states as style features."""

    def __init__(
            self,
            checkpoint,
            train_data,
            test_data):
        """
        checkpoint : path of the state dict loaded into the classifier
        train_data, test_data : stored on the instance unchanged
        """
        self.checkpoint = checkpoint
        self.train_data = train_data
        self.test_data = test_data

        # Architecture settings must match the checkpoint, which is loaded
        # with strict=True.
        num_tokens, hidden_dim, out_dim, dropout = 256, 512, 100, 0.35
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased', padding=True, truncation=True)
        extractor = BertModel.from_pretrained('bert-base-cased')
        head = LogisticRegression(768 * num_tokens, hidden_dim, out_dim, dropout=dropout)
        classifier = BertClassifier(extractor, head)
        self.model = load_model_dic(classifier, checkpoint, verbose=True, strict=True)

        # NOTE(review): this stores the generator returned by parameters(),
        # which can be iterated only once — confirm how callers consume it.
        self.parameters = self.model.parameters()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.max_length = 256

    def prepare_inputs(self, examples):
        """Tokenize ``examples['sentence1']``, padded/truncated to ``self.max_length``."""
        return self.tokenizer.encode_plus(
            examples['sentence1'],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

    def prepare_data(self, data):
        """Map ``prepare_inputs`` over ``data``, drop the raw text/label columns, set torch format."""
        tokenized = data.map(self.prepare_inputs)
        tokenized = tokenized.remove_columns(["sentence1", "label"])
        tokenized.set_format(type="torch",)
        return tokenized

    def get_style(self, encoded_dict):
        """Return the second output of the model (its feature tensor) for a tokenized batch, without gradients."""
        input_ids = encoded_dict['input_ids'].to(self.device)
        attention_mask = encoded_dict['attention_mask'].to(self.device)
        token_type_ids = encoded_dict['token_type_ids'].to(self.device)

        # The model indexes its input positionally: (ids, token types, mask).
        batch = input_ids, token_type_ids, attention_mask
        with torch.no_grad():
            _, feats, _ = self.model(batch, return_feat=True)
        return feats


class AverageMeter(object):
    """
    Tracks the latest value together with a running weighted sum, count and average.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


from math import log

class SupConLoss(nn.Module):
    def __init__(self, temperature=0.1, margin=0.2):
        """
        Variant of the loss described in the paper Supervised Contrastive Learning :
        https://arxiv.org/abs/2004.11362
        :param temperature: divisor applied to the cosine similarities
        :param margin: stored on the module; not used by forward()
        """
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.cos = nn.CosineSimilarity(dim=-1)
        self.margin = margin

    def forward(self, projections, targets):
        """
        :param projections: torch.Tensor, shape [batch_size, projection_dim]
        :param targets: torch.Tensor, shape [batch_size]
        :return: torch.Tensor, scalar
        """
        # Follow the input's own device rather than assuming the default GPU.
        device = projections.device

        # Cosine similarity between every pair of projections, scaled by temperature.
        dot_product_tempered = self.cos(projections.unsqueeze(1), projections.unsqueeze(0)) / self.temperature

        # Subtract the row maximum before exponentiating for numerical
        # stability; the epsilon keeps the later log away from log(0).
        exp_dot_tempered = (
            torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0]) + 1e-5
        )

        # same_class[i, j] is True when samples i and j share a label.
        same_class = (targets.unsqueeze(1) == targets.unsqueeze(0)).to(device)
        # Zero on the diagonal: a sample is never its own positive or negative.
        mask_anchor_out = 1 - torch.eye(exp_dot_tempered.shape[0], device=device)
        mask_combined_pos = same_class * mask_anchor_out
        mask_combined_neg = (~same_class) * mask_anchor_out

        # Number of positives per anchor. Anchors with no positive get a
        # divisor of 1 so the division below never hits zero (their numerator
        # is already zero because the positive mask is empty).
        cardinality_pos = torch.sum(mask_combined_pos, dim=1).clamp(min=1)

        # Sum of the exponentiated similarities over each sample's negatives, shape [batch_size].
        exp_sum_neg = torch.sum(exp_dot_tempered * mask_combined_neg, dim=1)
        # NOTE(review): exp_sum_neg is 1-D, so it broadcasts along columns:
        # entry [i, j] receives the negative sum of sample j, not of anchor i.
        # Kept as-is so loss values stay comparable with earlier runs — confirm
        # whether a per-anchor (keepdim=True) denominator was intended.
        prob = exp_dot_tempered / (exp_dot_tempered + exp_sum_neg + 1e-5)

        log_prob = -torch.log(prob) * mask_combined_pos

        total_loss = torch.mean(torch.sum(log_prob, dim=1) / cardinality_pos)

        return total_loss


class BertDataset(Dataset):
    """Map-style dataset that tokenizes texts lazily and caches the result per index."""

    def __init__(self, x, y, tokenizer, length=128, return_idx=False):
        """
        x          : indexable collection of input texts
        y          : labels, converted to one tensor with ``torch.tensor``
        tokenizer  : object providing ``batch_encode_plus``
        length     : every text is padded / truncated to this many tokens
        return_idx : when true, items also carry the index and the raw text
        """
        super(BertDataset, self).__init__()
        self.tokenizer = tokenizer
        self.length = length
        self.x = x
        self.return_idx = return_idx
        self.y = torch.tensor(y)
        self.tokens_cache = {}  # idx -> [input_ids, token_type_ids, attention_mask]

    def tokenize(self, x):
        """Return ``[input_ids, token_type_ids, attention_mask]`` for one text."""
        dic = self.tokenizer.batch_encode_plus(
            [x],  # input must be a list
            max_length=self.length,
            padding='max_length',
            truncation=True,
            return_token_type_ids=True,
            return_tensors="pt"
        )
        # Pick the fields by name so the result does not depend on the order
        # in which the tokenizer lists them; [0] drops the batch dimension.
        return [dic[key][0] for key in ("input_ids", "token_type_ids", "attention_mask")]

    def __getitem__(self, idx):
        """Return ``(input_ids, token_type_ids, attention_mask, label)``, plus ``(idx, text)`` if ``return_idx``."""
        int_idx = int(idx)
        assert idx == int_idx
        idx = int_idx
        if idx not in self.tokens_cache:
            self.tokens_cache[idx] = self.tokenize(self.x[idx])
        input_ids, token_type_ids, attention_mask = self.tokens_cache[idx]
        if self.return_idx:
            return input_ids, token_type_ids, attention_mask, self.y[idx], idx, self.x[idx]
        return input_ids, token_type_ids, attention_mask, self.y[idx]

    def __len__(self):
        return len(self.y)


def get_csv_dataset(data_file):
    """Load ``data_file`` with ``load_dataset`` using the "csv" builder and return the result.

    NOTE: this redefines the identical ``get_csv_dataset`` declared in the
    data section near the top of the notebook.
    """
    return load_dataset("csv", data_files=data_file)


In [None]:
import pandas as pd

# Load the processed train/test splits, keeping the prompt text and the
# user_name label column under the names used by the rest of the cell.
nlp_train = pd.read_csv('/content/drive/MyDrive/msc_project/data/diffusiondb/processed/train_random100_label_1.csv')
nlp_test = pd.read_csv('/content/drive/MyDrive/msc_project/data/diffusiondb/processed/test_random100_label_1.csv')
nlp_train = nlp_train[['prompt', 'user_name']]
nlp_train.columns = ['content', 'Target']
nlp_test = nlp_test[['prompt', 'user_name']]
nlp_test.columns = ['content', 'Target']


limit = 100
print("Number of authors: ", limit)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

style_checkpoint = "/content/drive/MyDrive/msc_project/model/contrastive/club/style_encoder.pt"

# Rebuild the classifier with the architecture the checkpoint expects
# (it is loaded with strict=True) and move it to the selected device.
num_tokens, hidden_dim, out_dim, dropout = 256, 512, 100, 0.35
ngpus, base_bs = 1, 32
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', padding=True, truncation=True)
extractor = BertModel.from_pretrained('bert-base-cased')
model = BertClassifier(extractor, LogisticRegression(768 * num_tokens, hidden_dim, out_dim, dropout=dropout))
model = load_model_dic(model, style_checkpoint, verbose=True, strict=True)

model.to(device)

# Test set and loader (order preserved: shuffle=False).
test_x, test_y = nlp_test['content'].tolist(), nlp_test['Target'].tolist()
test_set = BertDataset(test_x, test_y, tokenizer, num_tokens)
test_loader = DataLoader(test_set, batch_size=base_bs * ngpus, shuffle=False, num_workers=4 * ngpus,
                          pin_memory=True)


criterion = nn.CrossEntropyLoss()
supcon = SupConLoss()
coefficient = 1

model.eval()
pg = tqdm(test_loader, leave=False, total=len(test_loader), disable=False)

# Predictions and labels of every batch, in loader order.
all_preds = []
all_labels = []

with torch.no_grad():
    test_acc = AverageMeter()
    test_loss_1 = AverageMeter()
    test_loss_2 = AverageMeter()
    test_loss = AverageMeter()

    for i, (x1, x2, x3, y) in enumerate(pg):
        # Use the device selected above so the cell also runs without a GPU.
        x, y = (x1.to(device), x2.to(device), x3.to(device)), y.to(device)
        pred, feats, model_output = model(x, return_feat=True)
        batch_pred = pred.argmax(1)
        n_batch = len(y)

        # classification
        loss_1 = criterion(pred, y.long())

        # contrastive term, computed on the classifier outputs
        loss_2 = supcon(pred, y.long())

        # total loss
        loss = loss_1 + coefficient * loss_2

        # Weight every meter by the batch size so a smaller final batch does
        # not count as much as a full one.
        test_acc.update((batch_pred == y).sum().item() / n_batch, n_batch)
        test_loss.update(loss.item(), n_batch)
        test_loss_1.update(loss_1.item(), n_batch)
        test_loss_2.update(loss_2.item(), n_batch)

        # Append the predictions and labels to the lists
        all_preds.extend(batch_pred.cpu().numpy())
        all_labels.extend(y.cpu().numpy())

        pg.set_postfix({
            'test acc': '{:.6f}'.format(test_acc.avg),
        })
        print(batch_pred, y, test_acc.avg)

# Overall accuracy computed directly from the collected predictions.
print("All Predictions:", all_preds)
print("All Labels:", all_labels)
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)
accuracy = np.mean(all_preds == all_labels)
print("Accuracy:", accuracy)

print(test_acc.avg)
print(test_loss.avg)
print(test_loss_1.avg)
print(test_loss_2.avg)

In [None]:
import pandas as pd

# ---------------------------------------------------------------------------
# Evaluate the saved classifier checkpoint on the `test_random100_label_1` split.
# Prints the per-author sample counts, per-batch predictions, the overall
# accuracy and the averaged losses.
# ---------------------------------------------------------------------------

# Source column -> name used by the rest of this cell.
COLUMN_MAP = {'prompt': 'content', 'user_name': 'Target'}

nlp_train = pd.read_csv('/content/drive/MyDrive/msc_project/data/diffusiondb/processed/train_random100_label_1.csv')
nlp_test = pd.read_csv('/content/drive/MyDrive/msc_project/data/diffusiondb/processed/test_random100_label_1.csv')
# Keep only the text and author columns and rename them in one step.
nlp_train = nlp_train[list(COLUMN_MAP)].rename(columns=COLUMN_MAP)
nlp_test = nlp_test[list(COLUMN_MAP)].rename(columns=COLUMN_MAP)
print(len(nlp_test))
prompt_counts = nlp_test.groupby('Target').size()

print(prompt_counts)

limit = 100
print("Number of authors: ", limit)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

style_checkpoint = '/content/drive/MyDrive/msc_project/model/contrastive/contrax/exp_data/diffusiondb100_supcon_cls_bert-base-cased_coe1_temp0.1_unit6_epoch30/diffusiondb100_supcon_cls_val0.79216_e24.pt'


num_tokens, hidden_dim, out_dim, dropout = 256, 512, 100, 0.35
ngpus, base_bs = 1, 32
# NOTE(review): padding/truncation are normally passed when calling the tokenizer;
# confirm they have the intended effect as `from_pretrained` keyword arguments.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', padding=True, truncation=True)
extractor = BertModel.from_pretrained('bert-base-cased')
model = BertClassifier(extractor, LogisticRegression(768 * num_tokens, hidden_dim, out_dim, dropout=dropout))
model = load_model_dic(model, style_checkpoint, verbose=True, strict=True)

model.to(device)

# get dataset
test_x, test_y = nlp_test['content'].tolist(), nlp_test['Target'].tolist()
test_set = BertDataset(test_x, test_y, tokenizer, num_tokens)
# Pinned host memory only helps host-to-GPU copies, so enable it only on CUDA.
test_loader = DataLoader(test_set, batch_size=base_bs * ngpus, shuffle=False, num_workers=4 * ngpus,
                         pin_memory=(device.type == 'cuda'))


# Reported losses: cross-entropy plus a supervised-contrastive term weighted by `coefficient`.
criterion = nn.CrossEntropyLoss()
supcon = SupConLoss()
coefficient = 1

model.eval()
pg = tqdm(test_loader, leave=False, total=len(test_loader), disable=False)

# Initialize lists to store predictions and labels
all_preds = []
all_labels = []

with torch.no_grad():
    test_acc = AverageMeter()
    test_loss_1 = AverageMeter()
    test_loss_2 = AverageMeter()
    test_loss = AverageMeter()

    for x1, x2, x3, y in pg:
        # Move the batch to the device the model lives on (the original hard-coded
        # `.cuda()`, which fails on a CPU-only runtime despite the fallback above).
        x = (x1.to(device), x2.to(device), x3.to(device))
        y = y.to(device)
        pred, feats, model_output = model(x, return_feat=True)

        # classification
        loss_1 = criterion(pred, y.long())

        # contrastive learning
        # NOTE(review): SupConLoss is fed the logits `pred`, not `feats` -- confirm this
        # matches how the checkpoint was trained.
        loss_2 = supcon(pred, y.long())

        # total loss
        loss = loss_1 + coefficient * loss_2

        # Predicted class per sample, computed once and reused below.
        batch_preds = pred.argmax(1)

        # logger
        test_acc.update((batch_preds == y).sum().item() / len(y))
        test_loss.update(loss.item())
        test_loss_1.update(loss_1.item())
        test_loss_2.update(loss_2.item())

        # Append the predictions and labels to the lists
        all_preds.extend(batch_preds.cpu().numpy())
        all_labels.extend(y.cpu().numpy())

        pg.set_postfix({
            'test acc': '{:.6f}'.format(test_acc.avg),
        })
        # Same text as the former print(), routed through tqdm so the progress bar is not corrupted.
        tqdm.write('{} {} {}'.format(batch_preds, y, test_acc.avg))

# After the loop, you can now use all_preds and all_labels as needed
print("All Predictions:", all_preds)
print("All Labels:", all_labels)
# Sample-level accuracy over the full test set; unlike the mean of per-batch accuracies
# it is not skewed by a smaller final batch.
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)
accuracy = np.mean(all_preds == all_labels)
print("Accuracy:", accuracy)

# Meter summaries, in order: accuracy, total loss, cross-entropy loss, contrastive loss.
print(test_acc.avg)
print(test_loss.avg)
print(test_loss_1.avg)
print(test_loss_2.avg)

In [None]:
# Recompute the overall accuracy from the predictions and labels collected by the
# evaluation cell above (both names are rebound to fresh NumPy arrays).
all_preds, all_labels = np.array(all_preds), np.array(all_labels)
is_correct = all_preds == all_labels
accuracy = np.mean(is_correct)
print("Accuracy:", accuracy)

In [None]:
all_preds1 = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 54, 0, 0, 0, 1, 1, 1, 1, 1, 1, 51, 1, 1, 95, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 97, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 76, 5, 5, 5, 76, 5, 5, 5, 5, 5, 5, 5, 6, 6, 43, 6, 84, 6, 6, 14, 6, 55, 28, 6, 6, 74, 6, 6, 6, 74, 6, 6, 7, 37, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 44, 54, 8, 8, 8, 35, 8, 8, 8, 8, 8, 15, 8, 36, 8, 8, 8, 8, 16, 8, 64, 9, 9, 9, 59, 9, 56, 9, 9, 9, 9, 9, 79, 9, 9, 9, 9, 9, 9, 9, 10, 10, 80, 10, 40, 10, 10, 10, 97, 10, 10, 10, 10, 10, 54, 10, 10, 10, 10, 10, 11, 11, 29, 11, 11, 11, 11, 11, 11, 11, 11, 11, 51, 36, 11, 11, 1, 11, 52, 67, 12, 46, 12, 12, 12, 12, 12, 12, 12, 12, 97, 12, 12, 12, 1, 12, 12, 78, 12, 12, 13, 75, 15, 31, 13, 1, 84, 14, 13, 61, 13, 13, 13, 13, 13, 13, 13, 73, 13, 13, 95, 3, 14, 14, 14, 56, 14, 14, 88, 11, 14, 14, 54, 14, 14, 14, 14, 14, 14, 14, 35, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 62, 16, 16, 16, 16, 73, 43, 16, 16, 16, 17, 22, 16, 16, 16, 84, 16, 61, 16, 17, 17, 17, 17, 17, 17, 17, 17, 95, 17, 17, 17, 46, 43, 29, 17, 17, 17, 17, 94, 18, 28, 18, 18, 18, 18, 74, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 28, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 74, 19, 19, 19, 19, 19, 19, 19, 70, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 1, 21, 8, 87, 79, 21, 21, 21, 21, 21, 80, 21, 21, 21, 21, 21, 74, 21, 21, 43, 88, 14, 67, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 49, 22, 22, 29, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 46, 15, 24, 24, 24, 24, 62, 24, 24, 33, 19, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 13, 20, 26, 26, 26, 26, 26, 26, 26, 96, 26, 26, 26, 26, 26, 26, 0, 57, 26, 94, 27, 27, 
27, 27, 27, 83, 18, 7, 27, 27, 64, 27, 27, 27, 99, 27, 27, 27, 27, 27, 28, 49, 46, 28, 28, 28, 28, 28, 85, 28, 28, 28, 28, 28, 28, 28, 39, 28, 28, 28, 60, 22, 21, 8, 29, 37, 29, 70, 80, 83, 29, 29, 29, 85, 29, 29, 29, 87, 29, 56, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 13, 31, 31, 31, 13, 32, 51, 9, 32, 32, 32, 32, 32, 97, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 61, 33, 54, 43, 33, 80, 33, 33, 33, 33, 33, 56, 33, 33, 54, 33, 33, 14, 33, 34, 40, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 61, 34, 1, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 17, 35, 35, 35, 84, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 66, 37, 31, 33, 37, 36, 37, 37, 37, 37, 84, 37, 37, 37, 37, 14, 37, 44, 37, 87, 37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 40, 80, 81, 6, 80, 9, 34, 40, 26, 40, 40, 40, 19, 77, 40, 40, 1, 21, 40, 46, 41, 41, 41, 92, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 61, 41, 41, 42, 42, 42, 42, 84, 48, 84, 6, 42, 11, 24, 94, 42, 42, 42, 62, 42, 81, 42, 42, 43, 43, 43, 21, 43, 54, 43, 46, 43, 44, 95, 43, 43, 43, 43, 1, 43, 43, 43, 93, 44, 44, 44, 60, 44, 44, 44, 44, 74, 14, 44, 43, 43, 43, 44, 44, 44, 44, 44, 44, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 16, 34, 46, 84, 46, 46, 88, 46, 46, 52, 31, 80, 46, 46, 46, 43, 72, 90, 46, 66, 46, 47, 47, 33, 47, 97, 47, 47, 47, 47, 79, 47, 47, 47, 47, 47, 47, 47, 79, 47, 47, 48, 48, 48, 48, 6, 48, 55, 48, 48, 48, 48, 48, 48, 48, 48, 48, 32, 48, 48, 27, 49, 49, 48, 49, 49, 46, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 50, 77, 50, 50, 50, 50, 53, 50, 50, 50, 50, 47, 50, 50, 50, 6, 50, 50, 50, 50, 51, 84, 51, 51, 51, 51, 51, 51, 51, 51, 74, 14, 51, 51, 51, 51, 51, 51, 51, 51, 52, 52, 97, 52, 
52, 52, 52, 52, 43, 16, 52, 52, 52, 80, 52, 52, 52, 52, 52, 52, 32, 63, 53, 53, 53, 73, 53, 44, 97, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 54, 54, 40, 54, 54, 35, 54, 54, 54, 89, 54, 54, 54, 13, 54, 11, 33, 54, 95, 66, 55, 55, 55, 55, 85, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 56, 56, 16, 9, 56, 14, 64, 56, 95, 56, 0, 56, 28, 95, 84, 56, 56, 8, 56, 56, 57, 57, 57, 57, 57, 57, 57, 57, 5, 57, 57, 57, 53, 57, 57, 57, 44, 88, 57, 57, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 59, 32, 20, 59, 20, 59, 72, 59, 24, 90, 59, 59, 18, 59, 59, 59, 26, 59, 73, 59, 60, 60, 60, 31, 73, 80, 43, 60, 60, 82, 60, 34, 60, 60, 60, 60, 60, 65, 60, 74, 44, 44, 43, 61, 61, 61, 61, 61, 61, 32, 61, 61, 61, 56, 61, 61, 88, 61, 73, 74, 62, 62, 62, 62, 62, 62, 62, 64, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 21, 62, 63, 63, 63, 63, 63, 63, 63, 74, 63, 63, 63, 63, 63, 56, 37, 51, 63, 63, 63, 64, 64, 47, 64, 64, 64, 90, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 1, 64, 64, 64, 65, 65, 65, 65, 65, 65, 95, 65, 65, 6, 61, 35, 61, 65, 80, 51, 65, 65, 65, 65, 66, 28, 66, 66, 26, 8, 66, 66, 29, 66, 66, 28, 28, 66, 66, 34, 66, 66, 66, 66, 67, 82, 67, 97, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 88, 67, 67, 67, 67, 68, 68, 13, 68, 68, 68, 68, 68, 68, 68, 68, 94, 68, 68, 68, 68, 68, 68, 68, 68, 69, 3, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 70, 88, 70, 44, 70, 70, 31, 70, 70, 70, 70, 70, 70, 70, 70, 70, 7, 70, 70, 70, 71, 71, 71, 71, 71, 71, 71, 45, 66, 8, 29, 71, 71, 71, 71, 88, 71, 71, 71, 71, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 7, 72, 72, 72, 72, 72, 59, 72, 72, 73, 73, 61, 73, 16, 73, 73, 50, 73, 73, 73, 73, 33, 73, 73, 73, 19, 73, 66, 73, 74, 74, 74, 16, 74, 74, 74, 74, 74, 74, 74, 74, 74, 60, 6, 15, 74, 13, 74, 36, 75, 75, 75, 24, 75, 75, 6, 75, 54, 75, 75, 84, 75, 75, 60, 75, 75, 75, 75, 75, 76, 76, 76, 76, 76, 76, 76, 76, 76, 52, 76, 76, 76, 24, 76, 76, 14, 76, 76, 76, 77, 77, 77, 77, 77, 77, 12, 
77, 45, 77, 77, 56, 77, 95, 8, 11, 56, 77, 77, 77, 78, 78, 78, 44, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 17, 79, 79, 79, 79, 29, 80, 80, 34, 34, 61, 80, 80, 95, 80, 67, 98, 80, 80, 1, 80, 80, 80, 80, 80, 81, 40, 81, 8, 81, 81, 81, 81, 81, 81, 29, 81, 19, 18, 29, 81, 81, 81, 81, 81, 84, 82, 82, 82, 82, 82, 82, 42, 27, 82, 82, 82, 82, 94, 82, 82, 82, 82, 82, 82, 97, 14, 83, 83, 32, 83, 16, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 77, 84, 84, 44, 22, 84, 84, 84, 84, 43, 21, 14, 67, 61, 84, 84, 43, 84, 84, 84, 85, 85, 85, 75, 85, 85, 85, 85, 85, 85, 57, 85, 85, 73, 37, 85, 85, 85, 85, 85, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 87, 87, 87, 43, 87, 80, 46, 87, 87, 16, 88, 87, 87, 87, 95, 87, 87, 87, 40, 87, 88, 18, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 22, 89, 89, 48, 61, 89, 89, 89, 89, 89, 89, 97, 89, 89, 89, 89, 53, 89, 89, 89, 89, 90, 36, 37, 90, 90, 65, 90, 90, 73, 61, 90, 90, 90, 61, 85, 90, 90, 95, 90, 73, 91, 91, 91, 91, 91, 91, 91, 22, 91, 73, 29, 37, 91, 91, 91, 91, 91, 91, 40, 91, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 51, 92, 33, 92, 92, 26, 92, 92, 85, 93, 93, 20, 93, 93, 93, 6, 72, 93, 93, 34, 93, 75, 93, 53, 93, 11, 93, 93, 93, 94, 94, 55, 94, 27, 94, 94, 94, 94, 94, 64, 94, 94, 94, 28, 94, 19, 59, 94, 64, 95, 95, 95, 95, 95, 95, 95, 95, 6, 95, 95, 55, 95, 95, 43, 95, 95, 95, 95, 95, 96, 96, 96, 24, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 52, 96, 96, 96, 96, 96, 97, 97, 97, 97, 40, 97, 97, 97, 97, 97, 97, 97, 1, 74, 97, 97, 97, 84, 97, 97, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99]
all_preds2 = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 54, 0, 0, 0, 1, 1, 1, 1, 1, 1, 51, 1, 1, 95, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 97, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 76, 5, 5, 5, 76, 5, 5, 5, 5, 5, 5, 5, 6, 6, 43, 6, 84, 6, 6, 14, 6, 55, 28, 6, 6, 74, 6, 6, 6, 74, 6, 6, 7, 37, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 44, 54, 8, 8, 8, 35, 8, 8, 8, 8, 8, 15, 8, 36, 8, 8, 8, 8, 16, 8, 64, 9, 9, 9, 59, 9, 56, 9, 9, 9, 9, 9, 79, 9, 9, 9, 9, 9, 9, 9, 10, 10, 80, 10, 40, 10, 10, 10, 97, 10, 10, 10, 10, 10, 54, 10, 10, 10, 10, 10, 11, 11, 29, 11, 11, 11, 11, 11, 11, 11, 11, 11, 51, 36, 11, 11, 1, 11, 52, 67, 12, 46, 12, 12, 12, 12, 12, 12, 12, 12, 97, 12, 12, 12, 1, 12, 12, 78, 12, 12, 13, 75, 15, 31, 13, 1, 84, 14, 13, 61, 13, 13, 13, 13, 13, 13, 13, 73, 13, 13, 95, 3, 14, 14, 14, 56, 14, 14, 88, 11, 14, 14, 54, 14, 14, 14, 14, 14, 14, 14, 35, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 62, 16, 16, 16, 16, 73, 43, 16, 16, 16, 17, 22, 16, 16, 16, 84, 16, 61, 16, 17, 17, 17, 17, 17, 17, 17, 17, 95, 17, 17, 17, 46, 43, 29, 17, 17, 17, 17, 94, 18, 28, 18, 18, 18, 18, 74, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 28, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 74, 19, 19, 19, 19, 19, 19, 19, 70, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 1, 21, 8, 87, 79, 21, 21, 21, 21, 21, 80, 21, 21, 21, 21, 21, 74, 21, 21, 43, 88, 14, 67, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 49, 22, 22, 29, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 46, 15, 24, 24, 24, 24, 62, 24, 24, 33, 19, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 13, 20, 26, 26, 26, 26, 26, 26, 26, 96, 26, 26, 26, 26, 26, 26, 0, 57, 26, 94, 27, 27, 
27, 27, 27, 83, 18, 7, 27, 27, 64, 27, 27, 27, 99, 27, 27, 27, 27, 27, 28, 49, 46, 28, 28, 28, 28, 28, 85, 28, 28, 28, 28, 28, 28, 28, 39, 28, 28, 28, 60, 22, 21, 8, 29, 37, 29, 70, 80, 83, 29, 29, 29, 85, 29, 29, 29, 87, 29, 56, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 13, 31, 31, 31, 13, 32, 51, 9, 32, 32, 32, 32, 32, 97, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 61, 33, 54, 43, 33, 80, 33, 33, 33, 33, 33, 56, 33, 33, 54, 33, 33, 14, 33, 34, 40, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 61, 34, 1, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 17, 35, 35, 35, 84, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 66, 37, 31, 33, 37, 36, 37, 37, 37, 37, 84, 37, 37, 37, 37, 14, 37, 44, 37, 87, 37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 40, 80, 81, 6, 80, 9, 34, 40, 26, 40, 40, 40, 19, 77, 40, 40, 1, 21, 40, 46, 41, 41, 41, 92, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 61, 41, 41, 42, 42, 42, 42, 84, 48, 84, 6, 42, 11, 24, 94, 42, 42, 42, 62, 42, 81, 42, 42, 43, 43, 43, 21, 43, 54, 43, 46, 43, 44, 95, 43, 43, 43, 43, 1, 43, 43, 43, 93, 44, 44, 44, 60, 44, 44, 44, 44, 74, 14, 44, 43, 43, 43, 44, 44, 44, 44, 44, 44, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 16, 34, 46, 84, 46, 46, 88, 46, 46, 52, 31, 80, 46, 46, 46, 43, 72, 90, 46, 66, 46, 47, 47, 33, 47, 97, 47, 47, 47, 47, 79, 47, 47, 47, 47, 47, 47, 47, 79, 47, 47, 48, 48, 48, 48, 6, 48, 55, 48, 48, 48, 48, 48, 48, 48, 48, 48, 32, 48, 48, 27, 49, 49, 48, 49, 49, 46, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 50, 77, 50, 50, 50, 50, 53, 50, 50, 50, 50, 47, 50, 50, 50, 6, 50, 50, 50, 50, 51, 84, 51, 51, 51, 51, 51, 51, 51, 51, 74, 14, 51, 51, 51, 51, 51, 51, 51, 51, 52, 52, 97, 52, 
52, 52, 52, 52, 43, 16, 52, 52, 52, 80, 52, 52, 52, 52, 52, 52, 32, 63, 53, 53, 53, 73, 53, 44, 97, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 53, 54, 54, 40, 54, 54, 35, 54, 54, 54, 89, 54, 54, 54, 13, 54, 11, 33, 54, 95, 66, 55, 55, 55, 55, 85, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 56, 56, 16, 9, 56, 14, 64, 56, 95, 56, 0, 56, 28, 95, 84, 56, 56, 8, 56, 56, 57, 57, 57, 57, 57, 57, 57, 57, 5, 57, 57, 57, 53, 57, 57, 57, 44, 88, 57, 57, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 58, 59, 32, 20, 59, 20, 59, 72, 59, 24, 90, 59, 59, 18, 59, 59, 59, 26, 59, 73, 59, 60, 60, 60, 31, 73, 80, 43, 60, 60, 82, 60, 34, 60, 60, 60, 60, 60, 65, 60, 74, 44, 44, 43, 61, 61, 61, 61, 61, 61, 32, 61, 61, 61, 56, 61, 61, 88, 61, 73, 74, 62, 62, 62, 62, 62, 62, 62, 64, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 21, 62, 63, 63, 63, 63, 63, 63, 63, 74, 63, 63, 63, 63, 63, 56, 37, 51, 63, 63, 63, 64, 64, 47, 64, 64, 64, 90, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 1, 64, 64, 64, 65, 65, 65, 65, 65, 65, 95, 65, 65, 6, 61, 35, 61, 65, 80, 51, 65, 65, 65, 65, 66, 28, 66, 66, 26, 8, 66, 66, 29, 66, 66, 28, 28, 66, 66, 34, 66, 66, 66, 66, 67, 82, 67, 97, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 88, 67, 67, 67, 67, 68, 68, 13, 68, 68, 68, 68, 68, 68, 68, 68, 94, 68, 68, 68, 68, 68, 68, 68, 68, 69, 3, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 69, 70, 88, 70, 44, 70, 70, 31, 70, 70, 70, 70, 70, 70, 70, 70, 70, 7, 70, 70, 70, 71, 71, 71, 71, 71, 71, 71, 45, 66, 8, 29, 71, 71, 71, 71, 88, 71, 71, 71, 71, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 7, 72, 72, 72, 72, 72, 59, 72, 72, 73, 73, 61, 73, 16, 73, 73, 50, 73, 73, 73, 73, 33, 73, 73, 73, 19, 73, 66, 73, 74, 74, 74, 16, 74, 74, 74, 74, 74, 74, 74, 74, 74, 60, 6, 15, 74, 13, 74, 36, 75, 75, 75, 24, 75, 75, 6, 75, 54, 75, 75, 84, 75, 75, 60, 75, 75, 75, 75, 75, 76, 76, 76, 76, 76, 76, 76, 76, 76, 52, 76, 76, 76, 24, 76, 76, 14, 76, 76, 76, 77, 77, 77, 77, 77, 77, 12, 
77, 45, 77, 77, 56, 77, 95, 8, 11, 56, 77, 77, 77, 78, 78, 78, 44, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 78, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 17, 79, 79, 79, 79, 29, 80, 80, 34, 34, 61, 80, 80, 95, 80, 67, 98, 80, 80, 1, 80, 80, 80, 80, 80, 81, 40, 81, 8, 81, 81, 81, 81, 81, 81, 29, 81, 19, 18, 29, 81, 81, 81, 81, 81, 84, 82, 82, 82, 82, 82, 82, 42, 27, 82, 82, 82, 82, 94, 82, 82, 82, 82, 82, 82, 97, 14, 83, 83, 32, 83, 16, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 77, 84, 84, 44, 22, 84, 84, 84, 84, 43, 21, 14, 67, 61, 84, 84, 43, 84, 84, 84, 85, 85, 85, 75, 85, 85, 85, 85, 85, 85, 57, 85, 85, 73, 37, 85, 85, 85, 85, 85, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 86, 87, 87, 87, 43, 87, 80, 46, 87, 87, 16, 88, 87, 87, 87, 95, 87, 87, 87, 40, 87, 88, 18, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 88, 22, 89, 89, 48, 61, 89, 89, 89, 89, 89, 89, 97, 89, 89, 89, 89, 53, 89, 89, 89, 89, 90, 36, 37, 90, 90, 65, 90, 90, 73, 61, 90, 90, 90, 61, 85, 90, 90, 95, 90, 73, 91, 91, 91, 91, 91, 91, 91, 22, 91, 73, 29, 37, 91, 91, 91, 91, 91, 91, 40, 91, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 92, 51, 92, 33, 92, 92, 26, 92, 92, 85, 93, 93, 20, 93, 93, 93, 6, 72, 93, 93, 34, 93, 75, 93, 53, 93, 11, 93, 93, 93, 94, 94, 55, 94, 27, 94, 94, 94, 94, 94, 64, 94, 94, 94, 28, 94, 19, 59, 94, 64, 95, 95, 95, 95, 95, 95, 95, 95, 6, 95, 95, 55, 95, 95, 43, 95, 95, 95, 95, 95, 96, 96, 96, 24, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 52, 96, 96, 96, 96, 96, 97, 97, 97, 97, 40, 97, 97, 97, 97, 97, 97, 97, 1, 74, 97, 97, 97, 84, 97, 97, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 98, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99]
# Element-wise comparison of the two pasted prediction lists defined just above.
# The value printed as "Accuracy" is the fraction of positions where the two runs
# agree, not accuracy against ground-truth labels.
all_preds1 = np.array(all_preds1)
all_preds2 = np.array(all_preds2)
# Guard against a truncated paste: with different lengths `==` cannot compare
# position-by-position (older NumPy silently returns a scalar False, i.e. 0.0).
if all_preds1.shape != all_preds2.shape:
    raise ValueError(
        "Prediction lists differ in length: {} vs {}".format(all_preds1.shape, all_preds2.shape)
    )
accuracy = np.mean(all_preds1 == all_preds2)
print("Accuracy:", accuracy)