# Persuasion Techniques in the Text of Memes – Inference with Hierarchical Models

## Environment Setup

##### Disk Setup

In [1]:
# Mount Google Drive so the project folder configured in the next cell is reachable.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


In [2]:
import os

# Root of the project folder: input data, the scorer script and the prediction outputs
# all live below it. Set PERSUASION_FOLDER to run against another location. The trailing
# separator is enforced because later cells build paths by plain string concatenation.
folder_name = os.path.join(
    os.environ.get("PERSUASION_FOLDER", "/content/drive/MyDrive/persuasion_technique_detection/"),
    "",
)

##### Imports

In [3]:
!pip install transformers datasets wandb evaluate accelerate -qU sklearn_hierarchical_classification

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.6/190.6 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m252.8/252.8 kB[0m [31m2

In [32]:
import json
import numpy as np
import pandas as pd
import os
import torch
import subprocess
import json
import warnings

In [5]:
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoModelForSequenceClassification
from datasets import load_dataset
from transformers import Trainer
from transformers import AutoTokenizer, DataCollatorWithPadding

In [6]:
# Select the compute device: CUDA when available (and report the GPUs), otherwise the CPU.
if not torch.cuda.is_available():
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
    AVAIL_GPUS = 0
else:
    device = torch.device("cuda")
    AVAIL_GPUS = torch.cuda.device_count()
    print(f'There are {AVAIL_GPUS} GPU(s) available.')
    print('Device name:', torch.cuda.get_device_name(0))

There are 1 GPU(s) available.
Device name: Tesla T4


## Log in to Weights & Biases (W&B)

In [7]:
import os

import wandb

# Authenticate with Weights & Biases (the key is appended to ~/.netrc, see the output below).
wandb.login()

# Session-wide W&B settings. NOTE(review): WANDB_PROJECT / WANDB_ENTITY presumably let
# wandb.Api() resolve the bare artifact names used later; verify. WANDB_LOG_MODEL looks
# carried over from the training setup.
os.environ.update({
    'WANDB_PROJECT': "subtask1_transformer_encoder_classification",
    'WANDB_ENTITY': "tumnlp",
    "WANDB_LOG_MODEL": "end",
})

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


## Pre-trained Transformer Name

In [8]:
# Backbone name: selects the tokenizer and names the output folder of this notebook.
# NOTE(review): presumably must match the backbone of the artifacts in model_nodes.
# Alternative backbones: "xlm-roberta-base", "xlnet-base-cased",
# "microsoft/deberta-v3-base", "albert-base-v2".
checkpoint = "bert-base-cased"


## Data Preprocessing

In [9]:
# Labelled validation split (scored below) and unlabelled dev split (written to dev_pred.json).
val_path = f"{folder_name}data/subtask1/validation.json"
test_path = f"{folder_name}data/subtask1/dev_unlabeled.json"

val_files = {"val": val_path}
test_files = {"test": test_path}

dataset_val, dataset_test = (
    load_dataset("json", data_files=files) for files in (val_files, test_files)
)

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating val split: 0 examples [00:00, ? examples/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [10]:
# Flat list of the 20 leaf persuasion techniques of the hierarchy.
# NOTE(review): not referenced by the hierarchical inference cells; confirm before removing.
unique_labels = [
    "Appeal to authority",
    "Bandwagon",
    "Causal Oversimplification",
    "Smears",
    "Flag-waving",
    "Black-and-white Fallacy/Dictatorship",
    "Slogans",
    "Repetition",
    "Obfuscation, Intentional vagueness, Confusion",
    "Name calling/Labeling",
    "Exaggeration/Minimisation",
    "Whataboutism",
    "Doubt",
    "Misrepresentation of Someone's Position (Straw Man)",
    "Presenting Irrelevant Data (Red Herring)",
    "Thought-terminating cliché",
    "Loaded Language",
    "Reductio ad hitlerum",
    "Appeal to fear/prejudice",
    "Glittering generalities (Virtue)",
]

In [11]:
# Label hierarchy: each node maps to its direct children. It is a DAG, because some
# techniques sit under more than one parent (e.g. "Whataboutism" under both
# Ad Hominem and Distraction).
_label_tree = {
    "persuasion": ["Ethos", "Pathos", "Logos"],
    "ethos": ["Ad Hominem", "Bandwagon", "Appeal to authority", "Glittering generalities (Virtue)"],
    "pathos": ["Exaggeration/Minimisation", "Loaded Language", "Flag-waving", "Appeal to fear/prejudice"],
    "logos": ["Justification", "Reasoning", "Repetition", "Obfuscation, Intentional vagueness, Confusion"],
    "ad_hominem": ["Name calling/Labeling", "Doubt", "Smears", "Reductio ad hitlerum", "Whataboutism"],
    "justification": ["Bandwagon", "Appeal to authority", "Flag-waving", "Appeal to fear/prejudice", "Slogans"],
    "reasoning": ["Distraction", "Simplification"],
    "distraction": [
        "Misrepresentation of Someone's Position (Straw Man)",
        "Presenting Irrelevant Data (Red Herring)",
        "Whataboutism",
    ],
    "simplification": [
        "Causal Oversimplification",
        "Black-and-white Fallacy/Dictatorship",
        "Thought-terminating cliché",
    ],
}

# Each child list is wrapped in an outer list: MultiLabelBinarizer.fit takes an iterable
# of label sets, and one set holding every child makes each child a class.
persuasion_unique = [_label_tree["persuasion"]]
ethos_unique = [_label_tree["ethos"]]
pathos_unique = [_label_tree["pathos"]]
logos_unique = [_label_tree["logos"]]
ad_hominem_unique = [_label_tree["ad_hominem"]]
justification_unique = [_label_tree["justification"]]
reasoning_unique = [_label_tree["reasoning"]]
distraction_unique = [_label_tree["distraction"]]
simplification_unique = [_label_tree["simplification"]]


### Preprocess Multi-Labels

In [12]:
# One binarizer per hierarchy node, fitted on that node's child labels.
# fit() returns the fitted estimator itself, so construction and fitting can be chained.
mlb_persuasion = MultiLabelBinarizer().fit(persuasion_unique)
mlb_ethos = MultiLabelBinarizer().fit(ethos_unique)
mlb_pathos = MultiLabelBinarizer().fit(pathos_unique)
mlb_logos = MultiLabelBinarizer().fit(logos_unique)
mlb_ad_hominem = MultiLabelBinarizer().fit(ad_hominem_unique)
mlb_justification = MultiLabelBinarizer().fit(justification_unique)
mlb_reasoning = MultiLabelBinarizer().fit(reasoning_unique)
mlb_distraction = MultiLabelBinarizer().fit(distraction_unique)
mlb_simplification = MultiLabelBinarizer().fit(simplification_unique)

## Evaluation

### Predict the validation set and create the output JSON file

In [13]:
def write_json(path, data, test=False):
    """Write ``data`` to ``path`` as indented UTF-8 JSON, overwriting the file.

    Args:
        path: destination file path.
        data: a dict or list is written as-is; any other object (e.g. a pandas
            DataFrame) is first converted with ``data.to_dict("records")``.
        test: unused; kept so existing call sites keep working.
    """
    if not isinstance(data, (dict, list)):
        data = data.to_dict("records")

    # ensure_ascii=False writes raw non-ASCII characters (e.g. "cliché"), so the file
    # encoding is pinned rather than left to the platform default.
    with open(path, "w", encoding="utf-8") as output_file:
        json.dump(data, output_file, indent=2, ensure_ascii=False)

In [14]:
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# NOTE(review): data_collator is not passed to the Trainer built in return_trainer.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


def tokenize(examples):
    """Tokenize a batch of examples for ``Dataset.map(..., batched=True)``.

    Each text is truncated to the tokenizer's maximum length, and the batch is padded
    to its longest sequence.
    """
    return tokenizer(examples["text"], truncation=True, padding=True)


tokenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

In [15]:
def get_preds(mlb, predicted_logits, threshold):
    """Turn multi-label logits into per-example (labels, probabilities) pairs.

    Args:
        mlb: fitted MultiLabelBinarizer whose ``classes_`` order matches the logit columns.
        predicted_logits: raw logits with one row per example, e.g.
            ``Trainer.predict(...).predictions``. A 1-D input is treated as a single row.
        threshold: a label is predicted when its sigmoid probability is strictly greater
            than this value.

    Returns:
        A list with one ``(labels, probs)`` tuple per example. ``labels`` lists the
        predicted class names (in ``mlb.classes_`` column order) and ``probs`` holds their
        probabilities in the same order. An empty input yields ``[]``.
    """
    probs = torch.sigmoid(torch.as_tensor(predicted_logits))
    # Keep a 2-D (examples, labels) layout. A blanket .squeeze() would also drop the
    # example axis when exactly one example is predicted, breaking the per-row logic below.
    if probs.dim() == 1:
        probs = probs.unsqueeze(0)

    mask = probs > threshold
    label_sets = mlb.inverse_transform(mask.int().numpy())
    return [
        (list(labels), probs[row][mask[row]].tolist())
        for row, labels in enumerate(label_sets)
    ]

In [28]:
project_name = "subtask1_transformer_encoder_classification"


def return_trainer(model_name, unique_labels):
    """Fetch a fine-tuned model artifact from W&B and wrap it in a prediction Trainer.

    Args:
        model_name: W&B artifact name, e.g. an entry of ``model_nodes``.
        unique_labels: the node's child labels wrapped in a one-element list; the length
            of the inner list sets the classifier's number of outputs.

    Returns:
        A ``Trainer`` around the multi-label classification model. The model is moved to
        the GPU when CUDA is available.
    """
    model_dir = wandb.Api().artifact(model_name).download()
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir,
        num_labels=len(unique_labels[0]),
        problem_type="multi_label_classification",
    )
    trainer = Trainer(model=model)
    if torch.cuda.is_available():
        trainer.model = model.cuda()
    return trainer

In [29]:
# Model artifact per hierarchy node, copied from the training summary file.
# Only the artifact name is stored here, not its full path.
model_nodes = {
    "persuasion": "model-bert-base-cased-persuasion-memes_0.5threshold_5e-05learningRate:v1",
    "ethos": "model-bert-base-cased-ethos-memes_0.5threshold_5e-05learningRate:v1",
    "pathos": "model-bert-base-cased-pathos-memes_0.5threshold_5e-05learningRate:v1",
    "logos": "model-bert-base-cased-logos-memes_0.5threshold_5e-05learningRate:v0",
    "ad_hominem": "model-bert-base-cased-ad_hominem-memes_0.5threshold_5e-05learningRate:v0",
    "justification": "model-bert-base-cased-justification-memes_0.5threshold_5e-05learningRate:v0",
    "reasoning": "model-bert-base-cased-reasoning-memes_0.5threshold_5e-05learningRate:v0",
    "simplification": "model-bert-base-cased-simplification-memes_0.5threshold_5e-05learningRate:v0",
    "distraction": "model-bert-base-cased-distraction-memes_0.5threshold_5e-05learningRate:v0",
}

# Sigmoid decision threshold per node (same keys and order as model_nodes); all 0.5 for now.
threshold_nodes = dict.fromkeys(model_nodes, 0.5)



### Evaluate the validation set

In [30]:
# Validation examples with the gold "labels" column dropped; the scorer reads the gold labels from val_path.
prediction_set = dataset_val["val"].remove_columns(["labels"])

In [33]:
# Root level: predict the top-level appeals (Ethos / Pathos / Logos) for every example.
trainer = return_trainer(model_nodes["persuasion"], persuasion_unique)

threshold = threshold_nodes["persuasion"]
prediction_set_tokenized = prediction_set.map(tokenize, batched=True)

preds = get_preds(mlb_persuasion, trainer.predict(prediction_set_tokenized).predictions, threshold)
# id -> (labels, probs); insertion order follows the dataset row order.
final_ds = dict(zip(prediction_set_tokenized["id"], preds))

# Node label -> (key into model_nodes / threshold_nodes, child label set, fitted binarizer).
node_specs = {
    "Ethos": ("ethos", ethos_unique, mlb_ethos),
    "Pathos": ("pathos", pathos_unique, mlb_pathos),
    "Logos": ("logos", logos_unique, mlb_logos),
    "Ad Hominem": ("ad_hominem", ad_hominem_unique, mlb_ad_hominem),
    "Justification": ("justification", justification_unique, mlb_justification),
    "Reasoning": ("reasoning", reasoning_unique, mlb_reasoning),
    "Simplification": ("simplification", simplification_unique, mlb_simplification),
    "Distraction": ("distraction", distraction_unique, mlb_distraction),
}
# NOTE(review): Ad Hominem has been run at 0.4 rather than threshold_nodes["ad_hominem"]
# (0.5). The override is made explicit so results stay unchanged; move it into
# threshold_nodes if the lower threshold is intended.
threshold_overrides = {"ad_hominem": 0.4}

# Top-down order: each parent is expanded before its children
# (Ethos -> Ad Hominem, Logos -> Justification / Reasoning, Reasoning -> Simplification / Distraction).
model_seq = ["Ethos", "Pathos", "Logos", "Ad Hominem", "Justification", "Reasoning", "Simplification", "Distraction"]

for node in model_seq:
    # Examples currently carrying this node as one of their labels, in dataset order.
    ids = [id_ for id_, (labels, _) in final_ds.items() if node in labels]
    ids_lookup = set(ids)  # O(1) membership test inside the filter below

    prev_labels = [final_ds[x] for x in ids]
    curr_set = prediction_set_tokenized.filter(lambda x: x["id"] in ids_lookup)

    print(f"Node: {node}, curr_set: {curr_set}")
    if len(ids) == 0:
        warnings.warn(f"No predictions with Label {node} made!!!!!")
        continue
    # prev_labels (dict order) and curr_set rows (dataset order) are paired by position.
    # Duplicate ids in the dataset would silently misalign them, so fail loudly instead.
    assert len(curr_set) == len(ids), f"Duplicate ids under node {node}: cannot align predictions"

    key, node_unique, node_mlb = node_specs[node]
    threshold = threshold_overrides.get(key, threshold_nodes[key])
    trainer = return_trainer(model_nodes[key], node_unique)
    preds = get_preds(node_mlb, trainer.predict(curr_set).predictions, threshold)

    new_labels = []
    for (parent_labels, parent_probs), (child_labels, child_probs) in zip(prev_labels, preds):
        if child_labels:
            # Replace the node label (and its probability) with the child-level predictions.
            t = parent_labels.index(node)
            new_labels.append((
                parent_labels[:t] + parent_labels[t + 1:] + child_labels,
                parent_probs[:t] + parent_probs[t + 1:] + child_probs,
            ))
        else:
            # No child passed the threshold: keep the coarser node label.
            new_labels.append((parent_labels, parent_probs))

    final_ds.update(zip(ids, new_labels))


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-persuasion-memes_0.5threshold_5e-05learningRate:v1, 414.05MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:1.2


Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Filter:   0%|          | 0/500 [00:00<?, ? examples/s]

Node: Ethos, curr_set: Dataset({
    features: ['id', 'link', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 352
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-ethos-memes_0.5threshold_5e-05learningRate:v1, 414.05MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:2.7


Filter:   0%|          | 0/500 [00:00<?, ? examples/s]

Node: Pathos, curr_set: Dataset({
    features: ['id', 'link', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 155
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-pathos-memes_0.5threshold_5e-05learningRate:v1, 414.05MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:2.6


Filter:   0%|          | 0/500 [00:00<?, ? examples/s]

Node: Logos, curr_set: Dataset({
    features: ['id', 'link', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 319
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-logos-memes_0.5threshold_5e-05learningRate:v0, 414.05MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:2.5


Filter:   0%|          | 0/500 [00:00<?, ? examples/s]

Node: Ad Hominem, curr_set: Dataset({
    features: ['id', 'link', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 313
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-ad_hominem-memes_0.5threshold_5e-05learningRate:v0, 414.06MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:2.4


Filter:   0%|          | 0/500 [00:00<?, ? examples/s]

Node: Justification, curr_set: Dataset({
    features: ['id', 'link', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 206
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-justification-memes_0.5threshold_5e-05learningRate:v0, 414.06MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:2.6


Filter:   0%|          | 0/500 [00:00<?, ? examples/s]

Node: Reasoning, curr_set: Dataset({
    features: ['id', 'link', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 135
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-reasoning-memes_0.5threshold_5e-05learningRate:v0, 414.05MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:2.4


Node: Simplification, curr_set: Dataset({
    features: ['id', 'link', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 135
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-simplification-memes_0.5threshold_5e-05learningRate:v0, 414.05MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:2.4


Filter:   0%|          | 0/500 [00:00<?, ? examples/s]

Node: Distraction, curr_set: Dataset({
    features: ['id', 'link', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 0
})




In [34]:
top_k = 3
final_df = dict()
for sample_id, (labels, probs) in final_ds.items():
    # Highest probability first; the stable sort keeps the original order on ties.
    ranked = sorted(zip(labels, probs), key=lambda pair: -pair[1])

    # Keep the first top_k distinct labels (a label can appear under several parents).
    chosen = []
    for label, _prob in ranked:
        if len(chosen) == top_k:
            break
        if label not in chosen:
            chosen.append(label)

    final_df[sample_id] = chosen

In [35]:
# One row per example: its id and its (at most top_k) predicted technique labels.
pred_df = pd.DataFrame.from_dict({"id": list(final_df.keys()), "labels": list(final_df.values())})

summary_dir_path = folder_name + "subtask1/summary_inference_" + checkpoint + "/"
val_pred_file = summary_dir_path + "val_pred.json"
# exist_ok replaces the separate os.path.exists check (and its check-then-create race).
os.makedirs(summary_dir_path, exist_ok=True)

write_json(val_pred_file, pred_df)

### Evaluate using the scorer script

In [36]:
scorer = folder_name + "subtask1/subtask_1_2a.py"
# An argument list (no shell) passes paths containing spaces or shell metacharacters,
# e.g. a "My Drive" folder, to the scorer unchanged.
command = ["python3", scorer, "--gold_file_path", val_path, "--pred_file_path", val_pred_file]

result = subprocess.run(command, check=True, stdout=subprocess.PIPE, text=True)
output = result.stdout.strip()

# stdout is parsed as tab-separated "name=value" fields, taken positionally as
# hierarchical F1, precision and recall.
parts = output.split("\t")
f1_h, prec_h, rec_h = (float(part.split("=", 1)[1]) for part in parts[:3])

hierarchical_metrics = {"f1_hierarchical": f1_h, "precision_hierarchical": prec_h, "recall_hierarchical": rec_h}
hierarchical_metrics

{'f1_hierarchical': 0.53559,
 'precision_hierarchical': 0.50651,
 'recall_hierarchical': 0.5682}

### Create the dev-set output file

In [38]:
# Unlabelled dev split. It has no "labels" column to drop, unlike the validation split.
prediction_set = dataset_test["test"]

In [40]:
# Root level: predict the top-level appeals (Ethos / Pathos / Logos) for every example.
trainer = return_trainer(model_nodes["persuasion"], persuasion_unique)

threshold = threshold_nodes["persuasion"]
prediction_set_tokenized = prediction_set.map(tokenize, batched=True)

preds = get_preds(mlb_persuasion, trainer.predict(prediction_set_tokenized).predictions, threshold)
# id -> (labels, probs); insertion order follows the dataset row order.
final_ds = dict(zip(prediction_set_tokenized["id"], preds))

# Node label -> (key into model_nodes / threshold_nodes, child label set, fitted binarizer).
node_specs = {
    "Ethos": ("ethos", ethos_unique, mlb_ethos),
    "Pathos": ("pathos", pathos_unique, mlb_pathos),
    "Logos": ("logos", logos_unique, mlb_logos),
    "Ad Hominem": ("ad_hominem", ad_hominem_unique, mlb_ad_hominem),
    "Justification": ("justification", justification_unique, mlb_justification),
    "Reasoning": ("reasoning", reasoning_unique, mlb_reasoning),
    "Simplification": ("simplification", simplification_unique, mlb_simplification),
    "Distraction": ("distraction", distraction_unique, mlb_distraction),
}
# NOTE(review): Ad Hominem has been run at 0.4 rather than threshold_nodes["ad_hominem"]
# (0.5). The override is made explicit so results stay unchanged; move it into
# threshold_nodes if the lower threshold is intended.
threshold_overrides = {"ad_hominem": 0.4}

# Top-down order: each parent is expanded before its children
# (Ethos -> Ad Hominem, Logos -> Justification / Reasoning, Reasoning -> Simplification / Distraction).
model_seq = ["Ethos", "Pathos", "Logos", "Ad Hominem", "Justification", "Reasoning", "Simplification", "Distraction"]

for node in model_seq:
    # Examples currently carrying this node as one of their labels, in dataset order.
    ids = [id_ for id_, (labels, _) in final_ds.items() if node in labels]
    ids_lookup = set(ids)  # O(1) membership test inside the filter below

    prev_labels = [final_ds[x] for x in ids]
    curr_set = prediction_set_tokenized.filter(lambda x: x["id"] in ids_lookup)

    print(f"Node: {node}, curr_set: {curr_set}")
    if len(ids) == 0:
        warnings.warn(f"No predictions with Label {node} made!!!!!")
        continue
    # prev_labels (dict order) and curr_set rows (dataset order) are paired by position.
    # Duplicate ids in the dataset would silently misalign them, so fail loudly instead.
    assert len(curr_set) == len(ids), f"Duplicate ids under node {node}: cannot align predictions"

    key, node_unique, node_mlb = node_specs[node]
    threshold = threshold_overrides.get(key, threshold_nodes[key])
    trainer = return_trainer(model_nodes[key], node_unique)
    preds = get_preds(node_mlb, trainer.predict(curr_set).predictions, threshold)

    new_labels = []
    for (parent_labels, parent_probs), (child_labels, child_probs) in zip(prev_labels, preds):
        if child_labels:
            # Replace the node label (and its probability) with the child-level predictions.
            t = parent_labels.index(node)
            new_labels.append((
                parent_labels[:t] + parent_labels[t + 1:] + child_labels,
                parent_probs[:t] + parent_probs[t + 1:] + child_probs,
            ))
        else:
            # No child passed the threshold: keep the coarser node label.
            new_labels.append((parent_labels, parent_probs))

    final_ds.update(zip(ids, new_labels))


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-persuasion-memes_0.5threshold_5e-05learningRate:v1, 414.05MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:2.8


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

Node: Ethos, curr_set: Dataset({
    features: ['id', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 732
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-ethos-memes_0.5threshold_5e-05learningRate:v1, 414.05MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:1.2


Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

Node: Pathos, curr_set: Dataset({
    features: ['id', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 288
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-pathos-memes_0.5threshold_5e-05learningRate:v1, 414.05MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:1.0


Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

Node: Logos, curr_set: Dataset({
    features: ['id', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 653
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-logos-memes_0.5threshold_5e-05learningRate:v0, 414.05MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:1.0


Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

Node: Ad Hominem, curr_set: Dataset({
    features: ['id', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 636
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-ad_hominem-memes_0.5threshold_5e-05learningRate:v0, 414.06MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:1.0


Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

Node: Justification, curr_set: Dataset({
    features: ['id', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 408
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-justification-memes_0.5threshold_5e-05learningRate:v0, 414.06MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:1.0


Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

Node: Reasoning, curr_set: Dataset({
    features: ['id', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 277
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-reasoning-memes_0.5threshold_5e-05learningRate:v0, 414.05MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:1.0


Node: Simplification, curr_set: Dataset({
    features: ['id', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 277
})


[34m[1mwandb[0m: Downloading large artifact model-bert-base-cased-simplification-memes_0.5threshold_5e-05learningRate:v0, 414.05MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:1.0


Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

Node: Distraction, curr_set: Dataset({
    features: ['id', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 0
})




In [41]:
top_k = 3
final_df = dict()
for sample_id, (labels, probs) in final_ds.items():
    # Highest probability first; the stable sort keeps the original order on ties.
    ranked = sorted(zip(labels, probs), key=lambda pair: -pair[1])

    # Keep the first top_k distinct labels (a label can appear under several parents).
    chosen = []
    for label, _prob in ranked:
        if len(chosen) == top_k:
            break
        if label not in chosen:
            chosen.append(label)

    final_df[sample_id] = chosen

In [42]:
# One row per example: its id and its (at most top_k) predicted technique labels.
pred_df = pd.DataFrame.from_dict({"id": list(final_df.keys()), "labels": list(final_df.values())})

summary_dir_path = folder_name + "subtask1/summary_inference_" + checkpoint + "/"
dev_pred_file = summary_dir_path + "dev_pred.json"
# exist_ok replaces the separate os.path.exists check (and its check-then-create race).
os.makedirs(summary_dir_path, exist_ok=True)

write_json(dev_pred_file, pred_df)