In [11]:
import pandas as pd
import matplotlib.pyplot as plt

## Fine-Tuning LLaVA on the RMET

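This notebook analyzes the results of fine-tuning the LLaVA model on the Reading the Mind in the Eyes Test (RMET). We compare five versions of the model: the base model and four fine-tuned models trained for 1, 5, 7, and 10 epochs. The 1-, 5-, and 7-epoch models each have three result files (suffixes `-1` to `-3`), which are averaged in Section 2; the base and 10-epoch models have one result file each.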

### 1 RMET data

#### 1.1 Loading and checking the data

In [173]:
# Load one results file per model version. The file stem encodes the training
# level (base, 1ep, 5ep, 7ep, 10ep) and a numeric suffix (-1, -2, -3).
base1, df1, df2, df3, df4, df5, df6, df7, df8, df9, df10 = (
    pd.read_csv(path)
    for path in (
        'rmet_results/rmet_base-1.txt',
        'rmet_results/rmet_1ep-1.txt',
        'rmet_results/rmet_1ep-2.txt',
        'rmet_results/rmet_1ep-3.txt',
        'rmet_results/rmet_5ep-1.txt',
        'rmet_results/rmet_5ep-2.txt',
        'rmet_results/rmet_5ep-3.txt',
        'rmet_results/rmet_7ep-1.txt',
        'rmet_results/rmet_7ep-2.txt',
        'rmet_results/rmet_7ep-3.txt',
        'rmet_results/rmet_10ep-1.txt',
    )
)

In [176]:
# Place every file's responses side by side (one column each), aligned on the row index.
rmet = pd.concat(
    [base1, df1, df2, df3, df4, df5, df6, df7, df8, df9, df10],
    axis='columns',
)

In [177]:
# Preview the merged responses: one column per results file, one row per RMET item.
rmet

Unnamed: 0,llava_base-1,llava_1ep-1,llava_1ep-2,llava_1ep-3,llava_5ep-1,llava_5ep-2,llava_5ep-3,llava_7ep-1,llava_7ep-2,llava_7ep-3,llava_10ep-1
0,bored,bored,bored,bored,comforting,comforting,comforting,comforting,comforting,comforting,comforting
1,upset,upset,upset,upset,upset,upset,upset,upset,upset,upset,upset
2,convinced,convinced,convinced,convinced,convinced,convinced,convinced,convinced,convinced,convinced,convinced
3,insisting,insisting,insisting,insisting,insisting,insisting,insisting,insisting,insisting,insisting,insisting
4,worried,worried,worried,worried,worried,worried,worried,worried,worried,worried,worried
5,alarmed,alarmed,alarmed,alarmed,alarmed,alarmed,alarmed,alarmed,alarmed,alarmed,alarmed
6,uneasy,uneasy,uneasy,uneasy,uneasy,uneasy,uneasy,uneasy,uneasy,uneasy,uneasy
7,despondent,despondent,despondent,despondent,despondent,despondent,despondent,despondent,despondent,despondent,despondent
8,horrified,preoccupied,preoccupied,preoccupied,horrified,horrified,horrified,preoccupied,preoccupied,preoccupied,preoccupied
9,bored,cautious,cautious,cautious,cautious,cautious,cautious,cautious,cautious,cautious,cautious


In [178]:
# Discard any column in which every value is missing.
rmet = rmet.dropna(axis='columns', how='all')

In [179]:
# Discard any row in which no column has a value.
# NOTE(review): rows are later matched to the answer key by position, so this
# must only ever remove trailing junk rows, never a real item.
rmet = rmet.dropna(axis='index', how='all')

### 2 Task performance


#### 2.1 Load answers

In [180]:
# Load the answer key: one correct word per line, in the same order as the
# response rows it is compared against below.
answers_file = 'rmet_materials/answers.txt'
with open(answers_file, 'r', encoding='utf-8') as file:
    # Skip blank lines (e.g. extra empty lines at the end of the file) so the
    # key has exactly one entry per item and lines up with the response rows.
    answers = [line.strip() for line in file if line.strip()]

In [181]:
# Spot-check the first ten entries of the answer key.
answers[:10]

['playful',
 'upset',
 'desire',
 'insisting',
 'worried',
 'fantasizing',
 'uneasy',
 'despondent',
 'preoccupied',
 'cautious']

#### 2.2 Score responses against the answer key

In [182]:
# Work on an independent copy so the raw responses in `rmet` stay intact.
performance = rmet.copy(deep=True)

In [183]:
# Score every response: 1 if it equals the answer key for that item, else 0
# (missing responses score 0). The key is matched to rows by position, so both
# must have the same length; fail loudly here rather than mis-scoring.
assert len(answers) == len(rmet), (
    f'{len(answers)} answers but {len(rmet)} response rows'
)
answer_key = pd.Series(answers, index=rmet.index)
# Derive from `rmet` (not `performance`) so re-running this cell is idempotent.
performance = rmet.eq(answer_key, axis='index').astype('int64')

#### 2.3 Calculate number correct

In [184]:
# Total correct answers per results file: column sums of the 0/1 score table.
num_correct = performance.sum(axis='index')

In [185]:
# Correct-answer count for each results file.
num_correct

llava_base-1    17
llava_1ep-1     22
llava_1ep-2     20
llava_1ep-3     19
llava_5ep-1     18
llava_5ep-2     23
llava_5ep-3     21
llava_7ep-1     22
llava_7ep-2     21
llava_7ep-3     22
llava_10ep-1    22
dtype: int64

#### 2.4 Calculate performance increase over the base model

In [186]:
# One row per results file, holding its count of correct answers.
increase_performance = num_correct.to_frame(name='num_correct')

In [187]:
# Map each results label to its model type by dropping the trailing numeric
# suffix, e.g. 'llava_5ep-2' -> 'llava_5ep'. Split on the LAST '-' only so a
# model name that itself contains '-' is not truncated.
increase_performance['model'] = [
    label.rsplit('-', 1)[0] for label in increase_performance.index
]

In [188]:
# Average the correct-answer count over all results files of each model type.
model_performance = (
    increase_performance
    .groupby('model')['num_correct']
    .agg('mean')
)

In [189]:
# Turn the grouped Series back into a two-column table: model, num_correct.
model_performance = model_performance.reset_index(name='num_correct')

In [190]:
# Mean correct-answer count per model type.
model_performance

Unnamed: 0,model,num_correct
0,llava_10ep,22.0
1,llava_1ep,20.333333
2,llava_5ep,20.666667
3,llava_7ep,21.666667
4,llava_base,17.0


In [191]:
# 'percent' is the FRACTION of items answered correctly (0-1), and
# 'improvement' is that fraction minus the base model's fraction.
# The denominator comes from the answer key (36 items here) instead of a magic
# number. The base model is found by name, not by its sorted row position, so
# adding models that sort after 'llava_base' cannot silently shift the baseline.
N_ITEMS = len(answers)
BASE_MODEL = 'llava_base'

model_performance['percent'] = model_performance['num_correct'] / N_ITEMS
base_percent = model_performance.loc[
    model_performance['model'] == BASE_MODEL, 'percent'
]
assert len(base_percent) == 1, f'expected exactly one {BASE_MODEL!r} row'
model_performance['improvement'] = model_performance['percent'] - base_percent.iloc[0]

In [192]:
# Per-model summary: mean correct count, fraction correct, and change vs. the base model.
model_performance

Unnamed: 0,model,num_correct,percent,improvement
0,llava_10ep,22.0,0.611111,0.138889
1,llava_1ep,20.333333,0.564815,0.092593
2,llava_5ep,20.666667,0.574074,0.101852
3,llava_7ep,21.666667,0.601852,0.12963
4,llava_base,17.0,0.472222,0.0
