<a href="https://colab.research.google.com/github/kryuchkovdm/Distillation/blob/master/methods/Utils.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch.nn.utils.prune as prune
from contextlib import contextmanager

In [None]:
def get_intent_dataset(path):
    """Collect (utterance pair, intent) samples from dialogue JSON files.

    Every file in ``path`` is loaded with ``json.load`` and treated as a list
    of dialogues, each a dict with a ``'turns'`` list. Turns are consumed in
    pairs: the even-indexed turn supplies the intent
    (``turn['frames'][0]['state']['active_intent']``) and its utterance, and
    the following odd-indexed turn supplies the reply utterance. Each complete
    pair yields one sample ``"<even utterance> <odd utterance>"``.

    A dialogue with a malformed turn (missing key, empty ``frames``, wrong
    type) is abandoned from that turn on; samples already paired from it are
    kept. A trailing unpaired turn is ignored. Both lists are always extended
    together, so they stay aligned.

    Args:
        path: Directory containing the dialogue JSON files (with or without a
            trailing path separator).

    Returns:
        ``(dialogs_list, intents_list)`` -- two lists of equal length where
        ``intents_list[k]`` is the intent attached to ``dialogs_list[k]``.
    """
    intents_list = []
    dialogs_list = []
    # Sorted so the sample order does not depend on the filesystem.
    for file_name in tqdm(sorted(os.listdir(path))):
        with open(os.path.join(path, file_name), "r", encoding="utf-8") as read_file:
            data = json.load(read_file)
        # NOTE(review): this drops the last *dialogue* of a file that holds an
        # odd number of dialogues. Kept from the original; the turn pairing
        # below does not rely on it.
        if len(data) % 2 == 1:
            data = data[:-1]
        for dialogue in data:
            pending_intent = None
            pending_utterance = None
            try:
                for num, turn in enumerate(dialogue['turns']):
                    if num % 2 == 0:
                        pending_intent = turn['frames'][0]['state']['active_intent']
                        pending_utterance = turn['utterance']
                    else:
                        # Build the sample first so that a failure here leaves
                        # neither list extended.
                        sample = pending_utterance + ' ' + turn['utterance']
                        dialogs_list.append(sample)
                        intents_list.append(pending_intent)
            except (KeyError, IndexError, TypeError):
                # Malformed dialogue (or a non-dialogue entry): skip the rest.
                continue
    return dialogs_list, intents_list

In [None]:
@memory.cache(ignore=["teacher"])
def dataframe_to_dataset(df, teacher):
    """Tokenize ``df.title`` and pair each row with the teacher's outputs.

    Wrapped by ``memory.cache`` with ``teacher`` excluded from the cache key.
    Presumably ``memory`` is a ``joblib.Memory`` -- confirm. If so, a cached
    result for the same ``df`` is reused even when a different teacher is
    passed in.

    Args:
        df: DataFrame whose ``title`` column holds the input texts and whose
            ``label`` column can be cast with ``astype("int")``.
        teacher: Model that is moved to ``device`` and switched to eval mode
            (both persist after the call), then called batch by batch (32
            rows) under ``torch.no_grad()``. ``token_type_ids`` are passed
            only when ``teacher.base_model_prefix == "bert"``.

    Returns:
        ``TensorDataset`` of ``(input_ids, attention_mask, token_type_ids,
        labels, teacher_outputs)``. ``teacher_outputs`` is element 0 of each
        batch's output, moved to the CPU and concatenated along dim 0.
    """
    # NOTE(review): ``pad_to_max_length`` is deprecated in newer
    # ``transformers`` releases in favour of ``padding="max_length"`` plus
    # ``truncation=True``; check the installed version before changing it.
    encoded = tokenizer.batch_encode_plus(
        df.title.values.tolist(),
        max_length=128,
        pad_to_max_length=True,
        return_attention_mask=True,
        return_token_type_ids=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    token_type_ids = encoded["token_type_ids"]
    unlabeled = TensorDataset(input_ids, attention_mask, token_type_ids)

    teacher.to(device)
    teacher.eval()

    output_chunks = []
    for ids_batch, mask_batch, types_batch in tqdm(DataLoader(unlabeled, batch_size=32, shuffle=False)):
        ids_batch = ids_batch.to(device)
        mask_batch = mask_batch.to(device)
        types_batch = types_batch.to(device)
        model_inputs = dict(input_ids=ids_batch, attention_mask=mask_batch)
        if teacher.base_model_prefix == "bert":
            model_inputs["token_type_ids"] = types_batch
        with torch.no_grad():
            teacher_out = teacher(**model_inputs)
        # Move each batch's output back to the CPU before collecting it.
        output_chunks.append(teacher_out[0].to(torch.device("cpu")))

    labels = torch.tensor(df.label.astype("int").to_numpy(), dtype=torch.long)
    return TensorDataset(input_ids,
                         attention_mask,
                         token_type_ids,
                         labels,
                         torch.cat(output_chunks, axis=0))

In [None]:
def print_size_of_model(model):
    """Print (and return) the serialized size of ``model``'s state dict in MB.

    The state dict is written with ``torch.save`` to a scratch file inside a
    private temporary directory. No file in the current working directory is
    overwritten or deleted, and the scratch file is removed even if saving
    fails.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).

    Returns:
        The size in megabytes (bytes / 1e6). This is also printed as
        ``Size (MB): <value>``.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Same file name as before, only in an isolated directory.
        tmp_path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), tmp_path)
        size_mb = os.path.getsize(tmp_path) / 1e6
    print('Size (MB):', size_mb)
    return size_mb

In [None]:
@contextmanager
def single_thread():
    """Run the ``with`` body with ``torch.set_num_threads(1)``.

    The value reported by ``torch.get_num_threads()`` on entry is restored on
    exit, including when the body raises.
    """
    previous = torch.get_num_threads()
    torch.set_num_threads(1)
    try:
        yield
    finally:
        # Without ``finally`` an exception in the body would leave torch
        # pinned to one thread for the rest of the session.
        torch.set_num_threads(previous)

In [None]:
def predict(model, text, top_k=3):
    """Return the ``top_k`` most probable categories for a single text.

    ``text`` is encoded with the module-level ``tokenizer`` (max length 128,
    padded). Only ``input_ids`` are passed to ``model``, and its output is
    normalised with ``F.softmax(..., dim=1)``.

    Args:
        model: Called as ``model(input_ids)``. Its return value must be
            accepted by ``F.softmax(..., dim=1)``.
        text: Input string.
        top_k: Number of categories to return (default 3, the previous
            fixed value).

    Returns:
        List of ``(category_name, probability)`` tuples, most probable first.
        Names are looked up in ``category_index_reverce``.
    """
    features = tokenizer.batch_encode_plus([text],
                                           max_length=128,
                                           pad_to_max_length=True,
                                           return_attention_mask=True,
                                           return_token_type_ids=True,
                                           return_tensors="pt")
    # NOTE(review): the attention mask is computed but not passed to the
    # model; fine if ``model`` only takes ``input_ids`` -- confirm.
    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        output = model(features['input_ids'].to(device))
        pred = F.softmax(output, dim=1)
    # Convert once, instead of calling ``pred.tolist()`` for every result.
    probs = pred[0].tolist()
    # argsort is ascending; reverse it for most-probable-first order.
    ranked = pred.argsort(1)[0].tolist()[::-1]
    return [(category_index_reverce[idx], probs[idx]) for idx in ranked[:top_k]]

In [None]:
def calc_weights(model):
    """Count the weight elements of ``model``'s direct children.

    Only immediate children (``model.children()``) are inspected; nested
    submodules are not descended into. Children without a ``weight``
    attribute, or whose ``weight`` is ``None`` (e.g. ``nn.ReLU`` or a
    non-affine ``nn.LayerNorm``), are skipped instead of raising
    ``AttributeError``. This matches how ``calc_pytorch_weights`` treats them.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        Total number of elements across the children's ``weight`` tensors
        (0 for a model with no such children).
    """
    return sum(
        layer.weight.numel()
        for layer in model.children()
        if getattr(layer, "weight", None) is not None
    )

In [None]:
def calc_pruned_weights(model):
    """Sum the ``_mask.weight`` entries of ``model``'s direct children.

    Every direct child must expose ``layer._mask.weight``. Presumably this is
    a 0/1 mask from a custom pruning wrapper, which would make the result the
    number of kept weights -- confirm against the pruning code.

    Args:
        model: A ``torch.nn.Module`` whose children carry ``_mask.weight``.

    Returns:
        The summed mask as a Python ``int`` (fractional totals truncated).
        Returns 0 for a model with no children; previously that case raised
        ``AttributeError`` on ``0.item()``.
    """
    result = 0
    for layer in model.children():
        result += torch.sum(layer._mask.weight.reshape(-1))
    # ``result`` is a 0-dim tensor once any layer contributed, or still the
    # int 0 for a childless model; ``int()`` accepts both.
    return int(result)

In [None]:
def calc_pytorch_weights(model):
    """Count the effective weights of ``model``'s direct children.

    For a child carrying a ``weight_mask`` (the buffer added by
    ``torch.nn.utils.prune``), the mask entries are summed, i.e. the kept
    weights for a 0/1 mask. Otherwise the elements of ``weight`` are counted.
    Children with no ``weight`` (or ``weight is None``) contribute nothing.
    Only immediate children are inspected.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        Total count as a Python ``int``.
    """
    total = 0
    for layer in model.children():
        mask = getattr(layer, "weight_mask", None)
        if mask is not None:
            total += int(mask.sum().item())
            continue
        weight = getattr(layer, "weight", None)
        # Layers such as activations have no weight; skip them explicitly
        # instead of relying on a bare ``except``.
        if weight is not None:
            total += weight.numel()
    return total