In [5]:
import os

import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

from iit import IITModel
from torch_model_base import TorchModelBase


In [6]:
class HfBertClassifierModelIIT(IITModel):
    """BERT encoder with a linear classification head, run with interchange interventions.

    The wrapped encoder is fine-tuned (it is put in training mode); the only
    newly created parameters are those of `classifier_layer`.

    Parameters
    ----------
    n_classes : int
        Output dimensionality of the classifier head.
    model : transformers.BertModel
        Encoder to wrap. It must expose `embeddings.word_embeddings` and
        `encoder.layer`, and its output must carry `pooler_output`.
    weights_name : str
        Name of the pretrained checkpoint; stored for reference only.

    Notes
    -----
    `id_to_coords`, `_get_set` and the contents of `activation` are not
    defined in this class; they are assumed to be supplied by `IITModel` or
    by the caller -- TODO confirm against `iit.py`.
    """

    def __init__(self, n_classes, model, weights_name='bert-base-cased'):
        super().__init__()
        self.n_classes = n_classes
        self.weights_name = weights_name
        self.bert = model
        self.bert.train()
        self.hidden_dim = self.bert.embeddings.word_embeddings.embedding_dim
        # The only new parameters -- the classifier:
        self.classifier_layer = nn.Linear(
            self.hidden_dim, self.n_classes)
        # `BertModel` exposes its transformer blocks directly at
        # `encoder.layer`; it has no intermediate `.model` attribute.
        self.layers = self.bert.encoder.layer

    def no_IIT_forward(self, indices, attention_mask):
        """Plain forward pass: pooled BERT output fed to the classifier head.

        Parameters
        ----------
        indices : torch.LongTensor
            Token ids for the batch.
        attention_mask : torch.Tensor
            Attention mask matching `indices`.

        Returns
        -------
        torch.Tensor
            Output of `classifier_layer` applied to `pooler_output`.
        """
        reps = self.bert(
            indices, attention_mask=attention_mask)
        return self.classifier_layer(reps.pooler_output)

    def forward(self, X):
        """Run a base pass and an intervened (counterfactual) pass.

        Parameters
        ----------
        X : torch.Tensor
            Five tensors stacked along dim 1, in the order: base token ids,
            base attention mask, source token ids, source attention mask,
            intervention id (repeated along the last axis).

        Returns
        -------
        tuple of torch.Tensor
            `(counterfactual_logits, base_logits)`.
        """
        # Integer indexing on dim 1 already drops that axis, so no squeeze is
        # needed (a squeeze would also collapse a length-1 sequence axis).
        # Token ids are deliberately left as integers: the embedding lookup
        # cannot consume floats.
        base_indices, base_mask, source_indices, source_mask, coord_ids = [
            X[:, j, :] for j in range(5)]
        # A single intervention site is used for the whole batch: the first
        # id in the batch decides it.
        get = self.id_to_coords[int(coord_ids.flatten()[0])]

        # Source pass: presumably the hooks installed by `_get_set(get, None)`
        # record the targeted activation into `self.activation` -- TODO
        # confirm. The logits of this pass are not needed.
        self.activation = dict()
        handlers = self._get_set(get, None)
        try:
            self.no_IIT_forward(source_indices, source_mask)
        finally:
            # Always detach hooks so a failed pass cannot leave them behind.
            for handler in handlers:
                handler.remove()

        base_logits = self.no_IIT_forward(base_indices, base_mask)

        # Counterfactual pass: same base input, with the recorded source
        # activation handed to the hooks as the intervention.
        setter = {k: get[k] for k in get}
        setter["intervention"] = self.activation[f'{get["layer"]}-{get["start"]}-{get["end"]}']
        handlers = self._get_set(get, setter)
        try:
            counterfactual_logits = self.no_IIT_forward(base_indices, base_mask)
        finally:
            for handler in handlers:
                handler.remove()

        return counterfactual_logits, base_logits

    

class HfBertClassifierIIT(TorchModelBase):
    """Training wrapper that tokenizes base/source examples for `HfBertClassifierModelIIT`.

    Required keyword arguments
    --------------------------
    weights_name : str
        Checkpoint name used to load the `BertTokenizer`.
    model : transformers.BertModel
        Pretrained encoder handed to the graph built by `build_graph`.

    All other positional and keyword arguments are forwarded to
    `TorchModelBase`.
    """

    def __init__(self, *args, **kwargs):
        # These two keywords belong to this wrapper, so they are removed
        # before delegating. NOTE(review): this assumes `TorchModelBase` does
        # not itself need them -- confirm against `torch_model_base.py`.
        self.weights_name = kwargs.pop("weights_name")
        self.pretrained_model = kwargs.pop("model")
        # Kept for backward compatibility with code that reads `self.model`.
        # NOTE(review): the base class may rebind `self.model` to the built
        # graph, which is why `build_graph` reads `self.pretrained_model`.
        self.model = self.pretrained_model
        self.tokenizer = BertTokenizer.from_pretrained(self.weights_name)
        super().__init__(*args, **kwargs)
        self.params += ['weights_name']

    def build_graph(self):
        """Create the IIT classifier graph around the pretrained encoder.

        Requires `n_classes_`, which is set by `build_dataset`.
        """
        return HfBertClassifierModelIIT(
            self.n_classes_,
            self.pretrained_model,
            weights_name=self.weights_name)

    def build_dataset(self, base, source, base_y, IIT_y, coord_ids):
        """Tokenize the inputs and pack everything into a `TensorDataset`.

        Parameters
        ----------
        base, source : sequence
            Base and source examples in a form accepted by
            `BertTokenizer.batch_encode_plus`; must have the same length.
        base_y, IIT_y : sequence
            Gold labels for the base inputs and for the intervened
            (counterfactual) inputs.
        coord_ids : sequence of int or torch.Tensor
            One intervention id per example.

        Returns
        -------
        torch.utils.data.TensorDataset
            Pairs `(bigX, bigy)`. `bigX` stacks, along dim 1: base ids, base
            mask, source ids, source mask, and the intervention id repeated
            along the sequence axis. `bigy` stacks `(IIT_y, base_y)` along
            dim 1, matching the `(counterfactual, base)` order of the model's
            outputs.

        Raises
        ------
        ValueError
            If `base` and `source` differ in length.
        """
        base = list(base)
        source = list(source)
        if len(base) != len(source):
            raise ValueError(
                "base and source must contain the same number of examples; "
                f"got {len(base)} and {len(source)}")

        # Encode both halves in a single call so they are padded to one
        # common length; padding them separately yields tensors that cannot
        # be stacked whenever the longest sequences differ.
        encoded = self.tokenizer.batch_encode_plus(
            base + source,
            max_length=None,
            add_special_tokens=True,
            padding='longest',
            return_attention_mask=True)
        all_indices = torch.tensor(encoded['input_ids'])
        all_mask = torch.tensor(encoded['attention_mask'])
        n_base = len(base)
        base_indices, source_indices = all_indices[:n_base], all_indices[n_base:]
        base_mask, source_mask = all_mask[:n_base], all_mask[n_base:]

        # One label space for both targets, so a class index means the same
        # thing in either column of `bigy`.
        self.classes_ = sorted(set(base_y) | set(IIT_y))
        self.n_classes_ = len(self.classes_)
        class2index = dict(zip(self.classes_, range(self.n_classes_)))
        base_y = torch.tensor([class2index[label] for label in base_y])
        IIT_y = torch.tensor([class2index[label] for label in IIT_y])

        # Broadcast each intervention id along the sequence axis so it can be
        # stacked with the token tensors.
        coord_ids = torch.as_tensor(coord_ids, dtype=torch.long)
        coord_ids = coord_ids.unsqueeze(1).expand(-1, base_indices.shape[1])

        bigX = torch.stack(
            (base_indices, base_mask, source_indices, source_mask, coord_ids),
            dim=1)
        bigy = torch.stack((IIT_y, base_y), dim=1)
        return torch.utils.data.TensorDataset(bigX, bigy)


In [None]:
# Root directory for the NLI corpora, relative to the notebook's working directory.
DATA_HOME = os.path.join("data", "nlidata")

# Sub-directory expected to hold the SNLI 1.0 distribution.
SNLI_HOME = os.path.join(DATA_HOME, "snli_1.0")

# Sub-directory expected to hold the MultiNLI 1.0 distribution.
MULTINLI_HOME = os.path.join(DATA_HOME, "multinli_1.0")

# NOTE(review): `nli` is not imported in any visible cell -- confirm it is
# imported upstream. The reader built here is not bound to a name, so its only
# effect is being shown as this cell's output.
nli.SNLITrainReader(SNLI_HOME, samp_percentage=0.10, random_state=42)

In [8]:
# Build the IIT classifier wrapper around a pretrained `bert-base-cased` encoder.
model = HfBertClassifierIIT(
    weights_name='bert-base-cased',
    model =BertModel.from_pretrained("bert-base-cased"),
    batch_size=8,  # Small batches to avoid memory overload.
    max_iter=1,  # We'll search based on 1 iteration for efficiency.
    n_iter_no_change=5,   # Early-stopping params are for the
    early_stopping=True)  # final evaluation.

# NOTE(review): `param_grid` is not used anywhere in the visible cells, and
# 'hidden_dim' is not a constructor argument of `HfBertClassifierIIT` as shown
# (the model takes its hidden size from the BERT embeddings) -- confirm before
# running a search with it.
param_grid = {
    'gradient_accumulation_steps': [1, 4, 8],
    'eta': [0.00005, 0.0001, 0.001],
    'hidden_dim': [100, 200, 300]}

# NOTE(review): `get_IIT_MoNLI_dataset` is not defined or imported in the
# visible cells; the recorded output of this cell is the resulting NameError.
# The `*_test` names are what is passed to `fit` below -- confirm whether a
# training split was intended.
X_base_test, X_source_test, y_base_test, y_IIT_test, interventions = get_IIT_MoNLI_dataset(os.path.join("data", "MoNLI"))

# Arguments are in the order base inputs, source inputs, base labels, IIT labels, intervention ids.
model.fit(X_base_test, X_source_test, y_base_test, y_IIT_test, interventions)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


NameError: name 'get_IIT_MoNLI_dataset' is not defined