# imports

In [1]:
import torch
from transformers import AutoModel, AutoTokenizer
from scipy.spatial.distance import cosine
import pandas as pd
import gensim
import nltk
import sys
import gensim.utils as gensimUtils

# Get our models - The package will take care of downloading the models automatically
# For best performance: Muennighoff/SGPT-5.8B-weightedmean-nli-bitfit
tokenizer = AutoTokenizer.from_pretrained("Muennighoff/SGPT-125M-weightedmean-nli-bitfit")
model = AutoModel.from_pretrained("Muennighoff/SGPT-125M-weightedmean-nli-bitfit")
# Deactivate Dropout (There is no dropout in the above models so it makes no difference here but other SGPT models may have dropout)
model.eval()



In [2]:
sys.argv = [""]

ET=pd.read_csv('ETresp21question-no-image.csv', encoding='latin1') #get data from student responses set
df = pd.DataFrame(ET)    #set ET as dataframe
dfIdeal = pd.DataFrame({'Gsentences': ET.GA}) #define ideal answers for tokenization (good answers/Gans).
dfIdeal['tokenized_sents'] = dfIdeal.apply(lambda row: nltk.word_tokenize(row['Gsentences']), axis=1)
dfStudent = pd.DataFrame({'Ssentences': ET.SA}) #define student answers for tokenization.
dfStudent['tokenized_sents'] = dfStudent.apply(lambda row: nltk.word_tokenize(row['Ssentences']), axis=1)
Gans = dfIdeal['tokenized_sents'] ##renaming the 2 tokenized sent sets for ease.
Sans = dfStudent['tokenized_sents']
GansDict = (dfIdeal['tokenized_sents'].to_dict)  #dictionary for tokenized sents
SansDict = (dfStudent['tokenized_sents'].to_dict)


tokenizedGans = [(gensimUtils.simple_preprocess(i, deacc=True, min_len=1, max_len=14)) for i in ET.GA] #tokenize Gans and Sans for use in w2v, w2vB, and D2V models matching(LSA allows for unkown terms in tokenized strings. these other models do not.)
tokenizedSans = [(gensimUtils.simple_preprocess(i, deacc=True, min_len=1, max_len=14)) for i in ET.SA]

In [4]:
torch.__version__

'1.12.0'

In [4]:
GAlist = []
for i in ET.GA[0:100]:
    GAlist.append(i)
    
SAlist = []
for i in ET.SA[0:100]:
    SAlist.append(i)

# sentence embeddings, attention masks, weighting and mean pooling for ideal answers (GA)

In [5]:
batch_tokensGA = tokenizer(GAlist, padding=True, truncation=True, return_tensors="pt")

# Get the embeddings
with torch.no_grad():
    # Get hidden state of shape [bs, seq_len, hid_dim]
    last_hidden_stateGA = model(**batch_tokensGA, output_hidden_states=True, return_dict=True).last_hidden_state

# Get weights of shape [bs, seq_len, hid_dim]
weightsGA = (
    torch.arange(start=1, end=last_hidden_stateGA.shape[1] + 1)
    .unsqueeze(0)
    .unsqueeze(-1)
    .expand(last_hidden_stateGA.size())
    .float().to(last_hidden_stateGA.device)
)

# Get attn mask of shape [bs, seq_len, hid_dim]
input_mask_expandedGA = (
    batch_tokensGA["attention_mask"]
    .unsqueeze(-1)
    .expand(last_hidden_stateGA.size())
    .float()
)

# encodings for ideal responses

In [6]:
# Perform weighted mean pooling across seq_len: bs, seq_len, hidden_dim -> bs, hidden_dim
sum_embeddingsGA2 = torch.sum(last_hidden_stateGA * input_mask_expandedGA * weightsGA, dim=1)
sum_maskGA = torch.sum(input_mask_expandedGA * weightsGA, dim=1)

embeddingsGA = sum_embeddingsGA2 / sum_maskGA
print(embeddingsGA)

tensor([[ 0.5642,  1.0927, -0.2866,  ..., -0.1318,  1.1017,  1.3367],
        [ 0.5642,  1.0927, -0.2866,  ..., -0.1318,  1.1017,  1.3367],
        [ 0.5642,  1.0927, -0.2866,  ..., -0.1318,  1.1017,  1.3367],
        ...,
        [-0.7614, -0.5094,  0.6846,  ...,  1.2404, -0.1497, -1.2174],
        [-0.7614, -0.5094,  0.6846,  ...,  1.2404, -0.1497, -1.2174],
        [-0.7614, -0.5094,  0.6846,  ...,  1.2404, -0.1497, -1.2174]])


# sentence embeddings, attention masks, weighting and mean pooling for user responses (SA)

In [7]:
batch_tokensSA = tokenizer(SAlist, padding=True, truncation=True, return_tensors="pt")

# Get the embeddings
with torch.no_grad():
    # Get hidden state of shape [bs, seq_len, hid_dim]
    last_hidden_stateSA = model(**batch_tokensSA, output_hidden_states=True, return_dict=True).last_hidden_state

# Get weights of shape [bs, seq_len, hid_dim]
weightsSA = (
    torch.arange(start=1, end=last_hidden_stateSA.shape[1] + 1)
    .unsqueeze(0)
    .unsqueeze(-1)
    .expand(last_hidden_stateSA.size())
    .float().to(last_hidden_stateSA.device)
)

# Get attn mask of shape [bs, seq_len, hid_dim]
input_mask_expandedSA = (
    batch_tokensSA["attention_mask"]
    .unsqueeze(-1)
    .expand(last_hidden_stateSA.size())
    .float()
)

# encodings for user responses

In [8]:
# Perform weighted mean pooling across seq_len: bs, seq_len, hidden_dim -> bs, hidden_dim
sum_embeddingsSA2 = torch.sum(last_hidden_stateSA * input_mask_expandedSA * weightsSA, dim=1)
sum_maskSA = torch.sum(input_mask_expandedSA * weightsSA, dim=1)

embeddingsSA = sum_embeddingsSA2 / sum_maskSA

In [10]:
# Calculate cosine similarities
# Cosine similarities are in [-1, 1]. Higher means more similar
cosine_sim_0_1 = 1 - cosine(embeddingsSA[0], embeddingsGA[0])
cosine_sim_0_2 = 1 - cosine(embeddingsSA[1], embeddingsGA[1])
cosine_sim_0_3 = 1 - cosine(embeddingsSA[2], embeddingsGA[2])

print("Cosine similarity between \"%s\" and \"%s\" is: %.3f", (cosine_sim_0_1))
print("Cosine similarity between \"%s\" and \"%s\" is: %.3f", (cosine_sim_0_2))
print("Cosine similarity between \"%s\" and \"%s\" is: %.3f" ,(cosine_sim_0_3))

Cosine similarity between "%s" and "%s" is: %.3f 0.2322816550731659
Cosine similarity between "%s" and "%s" is: %.3f 0.1253511607646942
Cosine similarity between "%s" and "%s" is: %.3f 0.26908594369888306


In [17]:
coslist = []
i = 0
for cos in embeddingsSA:
    cossim = 1 - cosine(embeddingsSA[i], embeddingsGA[i])
    i += 1
    coslist.append(cossim)
    

In [27]:
print(coslist)

[0.18029339611530304, 0.09258537739515305, 0.20805393159389496, 0.12740613520145416, 0.09008137881755829, 0.08356942981481552, 0.07445529103279114, 0.11589616537094116, 0.39289483428001404, 0.10761389136314392, 0.13867785036563873, 0.06663820147514343, 0.4401565492153168, 0.20936426520347595, 0.0769776776432991, 0.10902649164199829, 0.15036354959011078, 0.12199797481298447, 0.09387026727199554, 0.12973879277706146, 0.4500756561756134, 0.5779394507408142, 0.3491973280906677, 0.586405336856842, 0.5891994833946228, 0.6395396590232849, 0.5970253348350525, 0.7139629125595093, 0.6450480818748474, 0.6524220108985901, 0.5496478080749512, 0.73369300365448, 0.589399516582489, 0.5819704532623291, 0.3421902358531952, 0.25750303268432617, 0.5998838543891907, 0.1437651813030243, 0.6487117409706116, 0.2535786032676697, 0.07951926440000534, 0.3697644770145416, 0.5458250045776367, 0.18977750837802887, 0.0902184322476387, 0.08031296730041504, 0.5078833699226379, 0.02787715010344982, 0.24939756095409393,