In [80]:
import os
import sys
import csv
import collections
import numpy as np
import gzip
import json

import torch
from torchtext.vocab import vocab
from torch.utils.data import Dataset
from torchtext.vocab import GloVe
from torchtext.data import get_tokenizer



In [19]:
import re
import string
import random
import pickle
import sys
from pprint import pprint

from tqdm import tqdm
import pandas as pd
from nltk.corpus import words
from nltk.tokenize import sent_tokenize

from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, RobertaTokenizerFast

In [20]:
"""
Main training script for Beer and Hotel.
"""

import argparse
from pathlib import Path
from time import time
import torch.nn as nn
import torch.nn.functional as F

import wandb

from rrtl.utils import (
    get_model_class,
    get_optimizer_class,
    args_factory,
    save_args,
    save_ckpt,
    load_ckpt,
    build_vib_path
)

from rrtl.logging_utils import log,log_cap
from rrtl.visualize import visualize_rationale,visualize_prob
from rrtl.stats import gold_rationale_capture_rate,cal_pns,first_capture
from rrtl.config import Config
import warnings

config = Config()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


#os.environ['CUDA_VISIBLE_DEVICES']='0,1,2,3'
warnings.filterwarnings('ignore')

In [21]:
def get_dataloader_class(args):
    if args.dataset_name in ('beer'):#, 'squad-addonesent', 'squad-addsent', 'squad-addonesent-pos0'):
        dataloader_class = SentimentDataLoader
    elif args.dataset_name == 'ga':
          dataloader_class = GADataLoader
    #    dataloader_class = SQUADNegRationaleDataLoader
    elif args.method=='a2r':
        dataloader_class = SentimentDataLoader
    else:
        raise ValueError('Dataloader not implemented.')
    return dataloader_class

In [53]:
def get_special_token_map(encoder_type):
    if encoder_type.startswith('roberta'):
        special_token_map = {
            'bos_token': '<s>',
            'eos_token': '</s>',
            'sep_token': '</s>',
            'cls_token': '<s>',
            'unk_token': '<unk>',
            'pad_token': '<pad>',
            'mask_token': '<mask>',
        }
    elif encoder_type.startswith('bert') or encoder_type.startswith('distilbert'):
        special_token_map = {
            'sep_token': '[SEP]',
            'cls_token': '[CLS]',
            'unk_token': '[UNK]',
            'pad_token': '[PAD]',
            'mask_token': '[MASK]',
        }
    return special_token_map


class BaseDataLoader:
    def __init__(self, args):
        self.args = args
        self.tok_kwargs = config.TOK_KWARGS
        self.tok_kwargs['max_length'] = self.args.max_length
        if self.args.dataset_name=='ga':
            with open('ga_code.pkl','rb') as f:
              self.tokenizer=pickle.load(f)
        elif self.args.encoder_type.startswith('bert') or self.args.encoder_type.startswith('distilbert'):
            self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', cache_dir=self.args.cache_dir)
        elif self.args.encoder_type.startswith('roberta'):
            self.tokenizer = RobertaTokenizerFast.from_pretrained(self.args.encoder_type, cache_dir=self.args.cache_dir)
        
        self.dataset_name_to_dataset_class = {
            'beer': SentimentDataset,
            'hotel': SentimentDataset
        }
        self._dataloaders = {}
        self.special_token_map = get_special_token_map(self.args.encoder_type)

    def _load_processed_data(self, mode):
        raise NotImplementedError

    def _build_dataloader(self, data, mode):
        dataset = self.dataset_name_to_dataset_class[self.args.dataset_name](
            self.args,
            data,
            self.tokenizer,
            self.tok_kwargs
        )
        collate_fn = dataset.collater
        batch_size = self.args.batch_size
        shuffle = True if mode == 'train' else False
        
        self._dataloaders[mode] = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
        )
        print(f'[{mode}] dataloader built => {len(dataset)} examples')
    
    def build(self, mode):
        data = self._load_raw_data(mode)
        self._build_dataloader(data, mode)

    def build_all(self):
        for mode in ['train', 'dev', 'test']:
            self.build(mode)
    
    def __getitem__(self, mode):
        return self._dataloaders[mode]

    @property
    def train(self):
        return self._dataloaders['train']

    @property
    def dev(self):
        return self._dataloaders['dev']
    
    @property
    def test(self):
        return self._dataloaders['test']


class BaseDataset(Dataset):
    def __init__(self, args, data, tokenizer, tok_kwargs):
        self.args = args
        self.data = data
        self.tokenizer = tokenizer
        self.tok_kwargs = tok_kwargs

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
    
    @property
    def num_batches(self):
        return len(self.data) // self.args.batch_size


class SentimentDataLoader(BaseDataLoader):
    def __init__(self, args):
        super(SentimentDataLoader, self).__init__(args)
        if args.dataset_split == 'all':
            self.build_all()
        else:
            self.build(args.dataset_split)

    def _load_raw_data(self, mode):
            datapoints = []
            aspect=self.args.aspect
            #scale='normal'
            scale=self.args.scale
            print('aspect:',aspect)
            print('mode:',mode)
            print('scale:',scale)
            
            #scale='noise'
            
            if mode=='pns':
              path = config.DATA_DIR / f'sentiment/data/pns'
            else:
              if scale=='normal':
                if self.args.attack_path is not None:
                    path = self.args.attack_path
                elif self.args.dataset_name == 'beer' and mode in ('train', 'dev'):
                    path = config.DATA_DIR / f'sentiment/data/source/beer_{aspect}.{mode}'
                elif self.args.dataset_name == 'beer' and mode == 'test':
                    path = config.DATA_DIR / f'sentiment/data/target/beer_{aspect}.train'
                elif self.args.dataset_name == 'hotel' and mode in ('train', 'dev'):
                    path = config.DATA_DIR / f'sentiment/data/oracle/hotel_{aspect}.{mode}'
                elif self.args.dataset_name == 'hotel' and mode == 'test':
                    path = config.DATA_DIR / f'sentiment/data/target/hotel_{aspect}.train'
                else:
                    raise ValueError('Dataset name not supported.')
                    
              if scale=='small':
                if self.args.attack_path is not None:
                    path = self.args.attack_path
                elif self.args.dataset_name == 'beer' and mode in ('train', 'dev'):
                    path = config.DATA_DIR / f'sentiment/data/source/beer_{aspect}.{mode}_120'
                elif self.args.dataset_name == 'beer' and mode == 'test':
                    path = config.DATA_DIR / f'sentiment/data/target/beer_{aspect}.train_120'
                elif self.args.dataset_name == 'hotel' and mode in ('train', 'dev'):
                    path = config.DATA_DIR / f'sentiment/data/oracle/hotel_{aspect}.{mode}_120'
                elif self.args.dataset_name == 'hotel' and mode == 'test':
                    path = config.DATA_DIR / f'sentiment/data/target/hotel_{aspect}.train_120'
                else:
                    raise ValueError('Dataset name not supported.')
                    
              if scale=='noise':
                if self.args.dataset_name == 'beer' and mode in ('train', 'dev'):
                    path = config.DATA_DIR / f'sentiment/data/source/beer_{aspect}.{mode}_noise'
                elif self.args.dataset_name == 'beer' and mode == 'test':
                    path = config.DATA_DIR / f'sentiment/data/target/beer_{aspect}.train'
            
                    
              if scale=='causal':
                if self.args.dataset_name == 'beer' and mode in ('train', 'dev'):
                    path = config.DATA_DIR / f'sentiment/data/source/beer_{aspect}.{mode}_causal'
                elif self.args.dataset_name == 'beer' and mode == 'test':
                    path = config.DATA_DIR / f'sentiment/data/target/beer_{aspect}.train_causal' 
              
                
            df = pd.read_csv(path, delimiter='\t')
            for index, row in df.iterrows():
                label = row['label']

                # this could be applied to both beer and hotel
                if label >= 0.6:
                    label = 1  # pos
                elif label <= 0.4:
                    label = 0  # neg
                else:
                    continue
                text = row['text']
                if 'rationale' in row:
                    rationale = [int(r) for r in row['rationale'].split()]
                else:
                    rationale = [-1] * len(row['text'].split())
                datapoints.append({
                    'label': label,
                    'text': text,
                    'rationale': rationale,
                })
            if self.args.debug:
              datapoints = datapoints[:200]
            return datapoints

    def _load_processed_data(self, mode):
        processed_datapoints = []
        datapoints = self._load_raw_data(mode)
        for datapoint in tqdm(datapoints, total=len(datapoints)):
            label = datapoint['label']
            input_tokens = ['[CLS]'] + datapoint['text'].split()
            rationale = [0] + datapoint['rationale']
            input_ids = []
            attention_mask = []
            rationale_ = []
            for input_token, r in zip(input_tokens, rationale):
                tokenized = self.tokenizer.encode_plus(input_token, add_special_tokens=False)
                input_ids += tokenized['input_ids']
                attention_mask += tokenized['attention_mask']
                ## make rationale cover subword
                rationale_ += [r] * len(tokenized['input_ids'])
            #check length of sub-word toke
            #print(len(input_ids))
            if len(input_ids) >= self.args.max_length:
                input_ids = input_ids[:self.args.max_length - 1] + [102]
                attention_mask = attention_mask[:self.args.max_length - 1] + [1]
                rationale = rationale_[:self.args.max_length - 1] + [0]
            else:
                input_ids = input_ids + [102] #102 is [SEP]
                attention_mask = attention_mask + [1]
                rationale = rationale_ + [0]
                
            input_ids = self.pad(input_ids)
            attention_mask = self.pad(attention_mask)
            rationale = self.pad(rationale)

            assert len(input_ids) == self.args.max_length

            processed_datapoints.append({
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'label': label,
                'rationale': rationale,
            })
        return processed_datapoints

    def pad(self, seq):
        return seq + (self.args.max_length - len(seq)) * [0]


class SentimentDataset(BaseDataset):
    def __init__(self, args, data, tokenizer, tok_kwargs):
        super(SentimentDataset, self).__init__(args, data, tokenizer, tok_kwargs)

    def collater(self, batch):
        device = 'cuda' if self.args.use_cuda else 'cpu'
        return [datapoint for datapoint in batch]


      


In [54]:
    
        parser = argparse.ArgumentParser()
        

        # experiment
        parser.add_argument("--scale", type=str, default="normal", help="[small |normal]")
        parser.add_argument("--dataset-split", type=str, default="all", help="[all | train | dev | test]")
        parser.add_argument("--encoder-type", type=str, default="bert-base-uncased")
        parser.add_argument("--decoder-type", type=str, default="bert-base-uncased")
        parser.add_argument("--wandb", action="store_true")
        parser.add_argument("--debug", action="store_true")
        parser.add_argument("--cache_dir", type=str, default=config.CACHE_DIR)
        parser.add_argument("--overwrite_cache", action="store_true")
        parser.add_argument("--attack_path", type=str, default=None)

        # cuda
        parser.add_argument("--device_id", type=int, default=0)
        parser.add_argument("--dataparallel", action="store_true")
        parser.add_argument("--inspect-gpu", action="store_true")
        parser.add_argument("--disable-cuda", action="store_true")

        # printing, logging, and checkpointing
        parser.add_argument("--print-every", type=int, default=80)
        parser.add_argument("--eval-interval", type=int, default=500)
        parser.add_argument("--disable-ckpt", action="store_true")

        # training
        parser.add_argument("--train_opt", type=int, default=0)
        parser.add_argument("--warm_epoch", type=int, default=0)
        parser.add_argument("--seed", type=int, default=42)
        parser.add_argument("--nseeds", type=int, default=5)
        parser.add_argument("--batch_size", type=int, default=64)
        parser.add_argument("--max_length", type=int, default=300)
        parser.add_argument("--num_epoch", type=int, default=20)
        parser.add_argument("--lr", type=float, default=5e-5)
        parser.add_argument("--dropout_rate", type=float, default=0.2)
        parser.add_argument("--no-shuffle", action="store_true")
        parser.add_argument("--optimizer", type=str, default="adamw")
        parser.add_argument("--grad_accumulation_steps", type=int, default=1)

        # VIB model
        parser.add_argument("--k", type=int, default=3) # number of samples for PNS
        parser.add_argument("--mu", type=float, default=0.1) # weight for PNS
        #parser.add_argument("--alpha", type=float, default=0.0) #for concise loss
        parser.add_argument("--lambda1", type=float, default=0.0) #for concise loss
        parser.add_argument("--lambda2", type=float, default=0.0) #for continuity loss
        parser.add_argument("--tau", type=float, default=1.0)
        parser.add_argument("--pi", type=float, default=0.2)
        parser.add_argument("--beta", type=float, default=0.0)
        parser.add_argument("--gamma", type=float, default=1.0)
        parser.add_argument("--gamma2", type=float, default=1.0)
        parser.add_argument("--use-gold-rationale", action="store_true")
        parser.add_argument("--use-neg-rationale", action="store_true")
        parser.add_argument("--fix-input", type=str, default=None)

        # SPECTRA model
        parser.add_argument("--budget", type=int, default=None)
        parser.add_argument("--budget_ratio", type=float, default=None)
        parser.add_argument("--temperature", type=float, default=1.0)
        parser.add_argument("--solver_iter", type=int, default=100)


        

        args = parser.parse_args("")
        
        args.run_name='beer_rnp' 
        args.scale='small'
        args.dataset_name='beer' 
        args.model_type='rnp_beer_token'
        args.aspect='Palate'
        args.method = args.run_name.split('_')[1]
        
        args.method = args.run_name.split('_')[1]
        args.use_cuda=True
        
        args=args_factory(args)

In [55]:
dataloader_class = get_dataloader_class(args)
dl = dataloader_class(args)

aspect: Palate
mode: train
scale: small
[train] dataloader built => 9592 examples
aspect: Palate
mode: dev
scale: small
[dev] dataloader built => 2294 examples
aspect: Palate
mode: test
scale: small
[test] dataloader built => 200 examples


In [67]:
dl

<__main__.SentimentDataLoader at 0x7f7dc1662910>

In [71]:
for batch_idx, batch in enumerate(dl.train):
    sentence=batch[0]['text']
    words=sentence.strip().split()
    sys.exit()

SystemExit: 

In [91]:
from transformers import PreTrainedTokenizerFast
fast_tokenizer = PreTrainedTokenizerFast(tokenizer_file="rrtl/tokenizer.json")

In [93]:
fast_tokenizer.tokenize("sslk wjsasa ssas")

['ssl', 'k', 'wj', 'sasa', 'ssas']

In [97]:
tokenizer("sslk wjsasa ssas")

['sslk', 'wjsasa', 'ssas']

In [99]:
tokenizer = get_tokenizer("basic_english")
global_vectors = GloVe(name='840B', dim=300)
global_vectors.get_vecs_by_tokens(tokenizer("sslk wjsasa ssas"), lower_case_backup=True)
global_vectors.stoi[tokenizer("sslk wjsasa ssas")[1]]

KeyError: 'wjsasa'

"clear dark red colored beer with a small tan head . smells of sweet malt and smoke , distinct toffee candy smell . slight coffee aroma but really this is more small candy maker caramel and less patent malt . starts out smoky and pleasant , tastes drier than it smells . keeps hitting the smoke and artisan caramel flavors . quite tasty and a smoked porter by taste . mouthfeel is light , blame the utah 4 % on tap only laws for brewing all their beers dry . still this is a mighty tasty `` light porter '' not really a style but an accurate description ."

In [83]:
tokenizer(sentence)

['clear',
 'dark',
 'red',
 'colored',
 'beer',
 'with',
 'a',
 'small',
 'tan',
 'head',
 '.',
 'smells',
 'of',
 'sweet',
 'malt',
 'and',
 'smoke',
 ',',
 'distinct',
 'toffee',
 'candy',
 'smell',
 '.',
 'slight',
 'coffee',
 'aroma',
 'but',
 'really',
 'this',
 'is',
 'more',
 'small',
 'candy',
 'maker',
 'caramel',
 'and',
 'less',
 'patent',
 'malt',
 '.',
 'starts',
 'out',
 'smoky',
 'and',
 'pleasant',
 ',',
 'tastes',
 'drier',
 'than',
 'it',
 'smells',
 '.',
 'keeps',
 'hitting',
 'the',
 'smoke',
 'and',
 'artisan',
 'caramel',
 'flavors',
 '.',
 'quite',
 'tasty',
 'and',
 'a',
 'smoked',
 'porter',
 'by',
 'taste',
 '.',
 'mouthfeel',
 'is',
 'light',
 ',',
 'blame',
 'the',
 'utah',
 '4',
 '%',
 'on',
 'tap',
 'only',
 'laws',
 'for',
 'brewing',
 'all',
 'their',
 'beers',
 'dry',
 '.',
 'still',
 'this',
 'is',
 'a',
 'mighty',
 'tasty',
 '``',
 'light',
 'porter',
 "'",
 "'",
 'not',
 'really',
 'a',
 'style',
 'but',
 'an',
 'accurate',
 'description',
 '.']

In [82]:
global_vectors.get_vecs_by_tokens(tokenizer(sentence), lower_case_backup=True)

tensor([[ 0.1146,  0.3763, -0.4446,  ...,  0.0250,  0.3489,  0.1445],
        [ 0.2370, -0.0127,  0.0258,  ..., -0.1554, -0.1004,  0.0244],
        [-0.2490, -0.2322, -0.0272,  ..., -0.3526,  0.4042,  0.1854],
        ...,
        [-0.6502,  0.3831, -0.6818,  ..., -0.2603,  0.3001,  0.4126],
        [-0.1262, -0.0680, -0.5887,  ..., -0.2915, -0.1349, -0.1408],
        [ 0.0120,  0.2075, -0.1258,  ...,  0.1387, -0.3605, -0.0350]])