In [1]:
# Environment setup (kept commented out): uncomment and run once if the
# packages below are not installed in the kernel's environment.
#!pip install textattack[tensorflow,optional]
#!pip install -U datasets

In [2]:
import os
import datetime

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision
import tensorflow as tf
import transformers
from transformers import CLIPProcessor, CLIPModel

import textattack
from textattack.models.wrappers import HuggingFaceModelWrapper

  warn(
  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


### BERT-base-uncased

In [20]:
# Load the pretrained `bert-base-uncased` checkpoint from the Hugging Face hub.
model_path1 = 'bert-base-uncased'

# Tokenizer and encoder are fetched from the same checkpoint identifier.
bert_tokenizer = transformers.AutoTokenizer.from_pretrained(model_path1)
bert_model = transformers.AutoModel.from_pretrained(model_path1)



In [23]:
# Obtain BERT encodings.
#
# Each example is tokenized, run through `bert_model`, and its `pooler_output`
# is written to `path` as one comma-separated row with the label appended as
# the last column. The elapsed wall-clock time is printed at the end.

cnt = 0                                             # stays 0 if the dataset is empty
path = 'dataset_embeddings/bertagnewstest.txt'      # output path
ds = textattack.datasets.HuggingFaceDataset('ag_news', split='test')   # AG News test split
#path = 'dataset_embeddings/bertmrtest.txt'      # path
#ds = test_dataset                               # train dataset

# Create the output directory up front so open() does not raise
# FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(path), exist_ok=True)

st = datetime.datetime.now()

# Inference only: no_grad() avoids building an autograd graph on every
# forward pass, which saves memory and time without changing the outputs.
with open(path, 'w') as f, torch.no_grad():
    for cnt, (text, label) in enumerate(ds, start=1):
        encoded = bert_tokenizer(text['text'], return_tensors="pt", padding=True, truncation=True)
        embedding = bert_model(**encoded).pooler_output.detach().numpy()

        # Flatten the embedding, append the label, and write a single CSV row.
        np.savetxt(f, np.append(embedding, label).reshape(1, -1), delimiter=',')
        # np.savetxt already terminates the row with a newline, so this adds a
        # blank separator line; kept so the layout matches existing files.
        f.write('\n')

        # Lightweight progress indicator: one dot per 1000 examples.
        if cnt % 1000 == 0:
            print('.', end='')

et = datetime.datetime.now()
print(et-st)

textattack: Loading [94mdatasets[0m dataset [94mag_news[0m, split [94mtest[0m.


.......0:13:04.945165


### DistilBERT

In [73]:
# Load the pretrained DistilBERT (base, uncased) checkpoint from the Hugging Face hub.
model_path2 = 'distilbert/distilbert-base-uncased'

# Tokenizer and encoder are fetched from the same checkpoint identifier.
disbert_tokenizer = transformers.AutoTokenizer.from_pretrained(model_path2)
disbert_model = transformers.AutoModel.from_pretrained(model_path2)



In [74]:
# Obtain DistilBERT encodings.
#
# Each example is tokenized, run through `disbert_model`, and the mean of
# `last_hidden_state` over the token axis is written to `path` as one
# comma-separated row with the label appended as the last column. The elapsed
# wall-clock time is printed at the end.

cnt = 0                                                 # stays 0 if the dataset is empty
path = 'dataset_embeddings/disbertagnewstrain.txt'      # output path
ds = textattack.datasets.HuggingFaceDataset('ag_news', split='train')  # AG News train split
#path = 'dataset_embeddings/bertmrtest.txt'      # path
#ds = test_dataset                               # train dataset

# Create the output directory up front so open() does not raise
# FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(path), exist_ok=True)

st = datetime.datetime.now()

# Inference only: no_grad() avoids building an autograd graph on every
# forward pass, which saves memory and time without changing the outputs.
with open(path, 'w') as f, torch.no_grad():
    for cnt, (text, label) in enumerate(ds, start=1):
        encoded = disbert_tokenizer(text['text'], return_tensors="pt", padding=True, truncation=True)
        # Mean-pool over the token axis (axis 1). Inputs are encoded one at a
        # time, so no padding positions enter the average.
        hidden = disbert_model(**encoded).last_hidden_state
        embedding = hidden.detach().numpy().mean(axis=1)

        # Flatten the embedding, append the label, and write a single CSV row.
        np.savetxt(f, np.append(embedding, label).reshape(1, -1), delimiter=',')
        # np.savetxt already terminates the row with a newline, so this adds a
        # blank separator line; kept so the layout matches existing files.
        f.write('\n')

        # Lightweight progress indicator: one dot per 1000 examples.
        if cnt % 1000 == 0:
            print('.', end='')

et = datetime.datetime.now()
print(et-st)

textattack: Loading [94mdatasets[0m dataset [94mag_news[0m, split [94mtrain[0m.


........................................................................................................................1:47:00.250764


### CLIP

In [7]:
# Load the pretrained CLIP checkpoint (model and matching processor).
model_name = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(model_name)
clip_processor = CLIPProcessor.from_pretrained(model_name)

# Keep a handle on the text tower only; it is the part used to embed text.
text_encoder = clip_model.text_model



In [81]:
# Obtain CLIP text encodings.
#
# Each example is tokenized by `clip_processor`, run through `text_encoder`,
# and its `pooler_output` is written to `path` as one comma-separated row with
# the label appended as the last column. The elapsed wall-clock time is
# printed at the end.

cnt = 0                                             # stays 0 if the dataset is empty
path = 'dataset_embeddings/clipagnewstest.txt'      # output path
ds = textattack.datasets.HuggingFaceDataset('ag_news', split='test')   # AG News test split

# Create the output directory up front so open() does not raise
# FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(path), exist_ok=True)

st = datetime.datetime.now()

# Inference only: no_grad() avoids building an autograd graph on every
# forward pass, which saves memory and time without changing the outputs.
with open(path, 'w') as f, torch.no_grad():
    for cnt, (text, label) in enumerate(ds, start=1):
        # Pass the string as `text=` explicitly: the processor accepts both
        # text and images, so relying on positional order is fragile.
        encoded = clip_processor(text=text['text'], return_tensors="pt", padding=True, truncation=True)
        embedding = text_encoder(**encoded).pooler_output.detach().numpy()

        # Flatten the embedding, append the label, and write a single CSV row.
        np.savetxt(f, np.append(embedding, label).reshape(1, -1), delimiter=',')
        # np.savetxt already terminates the row with a newline, so this adds a
        # blank separator line; kept so the layout matches existing files.
        f.write('\n')

        # Lightweight progress indicator: one dot per 1000 examples.
        if cnt % 1000 == 0:
            print('.', end='')

et = datetime.datetime.now()
print(et-st)

textattack: Loading [94mdatasets[0m dataset [94mag_news[0m, split [94mtest[0m.


.......0:08:27.974111
