In [None]:
import os
import argparse
import torch
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch import nn, optim
from data import HierDataModule
from data import infer_preprocess
from ERDE import ERDE_sample
from model import HierClassifier
from transformers import AutoTokenizer
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from collections import defaultdict, Counter
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from model import HierClassifier
from ERDE import ERDE_chunk
import xml.dom.minidom
import string

In [None]:
# Restore the trained hierarchical classifier and put it in inference mode on the GPU.
clf = HierClassifier.load_from_checkpoint(
    "[pretrained ckpt]"  # placeholder: point this at the trained checkpoint
)
clf.eval()
clf.cuda()

# Reuse the tokenizer name and max sequence length stored with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(clf.model_type)
max_len = clf.hparams.max_len

# Capacity of the per-user post bank used by the streaming evaluation below.
max_posts = 16
None  # keep the cell from displaying the module repr

In [None]:
# Cached posts, user->post index mappings, labels and sentence embeddings for both
# splits (embeddings presumably from the MiniLM-L6 encoder, judging by the filename).
# NOTE: pickle.load can execute arbitrary code — only load files you produced yourself.
with open("processed/miniLM_L6_embs.pkl", "rb") as fh:
    data = pickle.load(fh)

train_posts, train_mappings, train_tags, train_embs = (
    data["train_posts"],
    data["train_mappings"],
    data["train_labels"],
    data["train_embs"],
)
test_posts, test_mappings, test_tags, test_embs = (
    data["test_posts"],
    data["test_mappings"],
    data["test_labels"],
    data["test_embs"],
)

In [None]:
# Sentence encoder used to embed the template sentences below; presumably the same model
# that produced the cached post embeddings — TODO confirm.
sbert = SentenceTransformer("paraphrase-MiniLM-L6-v2")
sbert = sbert.cuda()

In [None]:
# Anchor sentences used to score each post by cosine similarity:
# first, explicit first-person mentions of depression ...
depression_texts = [
    "I feel depressed.",
    "I am diagnosed with depression.",
    "I am treating my depression.",
]
# ... then one first-person paraphrase per questionnaire item (the 21 items appear to
# follow the BDI-II symptom list — TODO confirm the source).
questionaire_single = [
    "I feel sad.",
    "I am discouraged about my future.",
    "I always fail.",
    "I don't get pleasure from things.",
    "I feel quite guilty.",
    "I expected to be punished.",
    "I am disappointed in myself.",
    "I always criticize myself for my faults.",
    "I have thoughts of killing myself.",
    "I always cry.",
    "I am hard to stay still.",
    "It's hard to get interested in things.",
    "I have trouble making decisions.",
    "I feel worthless.",
    "I don't have energy to do things.",
    "I have changes in my sleeping pattern.",
    "I am always irritable.",
    "I have changes in my appetite.",
    "I feel hard to concentrate on things.",
    "I am too tired to do things.",
    "I have lost my interest in sex.",
]

# Row i of template_embeddings is the embedding of the i-th sentence above
# (depression_texts first, then questionaire_single).
template_embeddings = sbert.encode([*depression_texts, *questionaire_single])
template_embeddings.shape

In [None]:
def _predict_bank(posts):
    """Run the hierarchical classifier on one user's bank of posts.

    Returns the sigmoid of the classifier logit as a Python float.
    """
    batch = infer_preprocess(tokenizer, posts, max_len)
    for k, v in batch.items():
        batch[k] = v.cuda()
    with torch.no_grad():
        logits, _attn_score = clf([batch])
    return torch.sigmoid(logits).detach().cpu().item()


# Streaming evaluation: each user's posts are fed one at a time into a bank of at
# most `max_posts` posts. Once the bank is full, the bank keeps the posts whose best
# cosine similarity to any template is highest. The classifier is re-run only when
# the bank changes; otherwise the previous probability is repeated.
sample_pred_probas = []  # per user: one probability per processed post
user_basis = []          # per user: best-matching template id of each post left in the bank
num_updates_users = []   # per user: number of actual classifier calls
num_posts_users = []     # per user: number of posts processed
for mappings in tqdm(test_mappings, total=len(test_mappings)):
    # Posts are processed in reversed `mappings` order (presumably oldest first — TODO confirm).
    post_ids = mappings[::-1]
    pred_probas = []
    posts_bank = []
    scores_bank = []
    basis_bank = []
    num_updates = 0
    for post_id in post_ids:
        new_post = test_posts[post_id]
        # Embedding of the SAME post. The previous code indexed test_embs with the
        # un-reversed `mappings`, pairing each post with another post's embedding.
        new_emb = test_embs[post_id].reshape(1, -1)
        new_scores = cosine_similarity(new_emb, template_embeddings)[0]
        best_template_id = new_scores.argmax()
        new_score = new_scores[best_template_id]

        if len(posts_bank) >= max_posts:
            # Bank is full: the new post may only replace the lowest-scoring one.
            min_id = int(np.argmin(scores_bank))
            if new_score < scores_bank[min_id]:
                # Bank unchanged -> skip inference and repeat the last prediction.
                pred_probas.append(pred_probas[-1])
                continue
            del posts_bank[min_id]
            del scores_bank[min_id]
            del basis_bank[min_id]  # was `mid_id` (NameError)

        posts_bank.insert(0, new_post)
        scores_bank.insert(0, new_score)
        basis_bank.insert(0, best_template_id)
        pred_probas.append(_predict_bank(posts_bank))
        num_updates += 1
        # TODO: stop early once a decision condition is met.

    sample_pred_probas.append(pred_probas)
    num_updates_users.append(num_updates)
    num_posts_users.append(len(post_ids))
    user_basis.append(basis_bank)
len(sample_pred_probas)

In [None]:
# Wrap the per-user counters as Series so they support vectorized division and .describe().
num_updates_users, num_posts_users = (
    pd.Series(num_updates_users),
    pd.Series(num_posts_users),
)

In [None]:
# Per-user distribution of posts processed, classifier calls made, and the
# fraction of posts that actually triggered a classifier call.
summary_columns = {
    "num_user_posts": num_posts_users.describe(),
    "num_infers": num_updates_users.describe(),
    "infer_portion": (num_updates_users / num_posts_users).describe(),
}
pd.DataFrame(summary_columns)

In [None]:
# Fraction of all test posts that triggered an actual classifier call
# (the remaining posts reused the previous prediction).
total_infers = num_updates_users.sum()
total_posts = num_posts_users.sum()
total_infers / total_posts

In [None]:
# ERDE at o=5 and o=50, both using a 0.5 decision threshold on the per-post probabilities.
ERDE5, ERDE50 = (
    ERDE_sample(sample_pred_probas, test_tags, threshold=0.5, o=delay)
    for delay in (5, 50)
)
print(ERDE5, ERDE50)

## Analyze per-post attention weights

In [None]:
def infer_texts(texts):
    """Score a list of post texts with the hierarchical classifier.

    Returns:
        (probability, attention): `probability` is the sigmoid of the classifier
        logit as a float; `attention` is element 0 of the attention scores returned
        by the model, as a NumPy array (presumably one weight per post — TODO confirm).
    """
    encoded = infer_preprocess(tokenizer, texts, max_len)
    for name in list(encoded.keys()):
        encoded[name] = encoded[name].cuda()
    with torch.no_grad():
        logits, attention = clf([encoded])
    probability = torch.sigmoid(logits).detach().cpu().item()
    weights = attention[0].detach().cpu().numpy()
    return probability, weights

In [None]:
def show_attention(fname, max_chars=400):
    """Print the depression probability for one processed user file, then each
    post's attention weight next to the first `max_chars` characters of the post.

    Each line of `fname` is treated as one post and passed to `infer_texts`.
    """
    # Context manager closes the file (the copy-pasted cells leaked the handle).
    with open(fname) as f:
        texts = f.readlines()
    prob, attn_score = infer_texts(texts)
    print(fname)
    print("Depression prob", prob)
    for text, attn in zip(texts, attn_score):
        print(attn, text.strip()[:max_chars])


# File names end in _1 or _0 — presumably the user's label (1 = depressed); TODO confirm.
for fname in [
    "./processed/combined_maxsim16/test/000385_1.txt",
    "./processed/combined_maxsim16/test/000360_1.txt",
    "./processed/combined_maxsim16/test/000365_1.txt",
    "./processed/combined_maxsim16/test/000370_1.txt",
    "./processed/combined_maxsim16/test/000000_0.txt",
    "./processed/combined_maxsim16/test/000001_0.txt",
]:
    show_attention(fname)
    print()