In [1]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from pandas import DataFrame
from scipy import interpolate
from scipy.stats import pearsonr, spearmanr

In [2]:
# Read every interval from the CSV file; intervals are told apart by the rule
# that 'freq' must be ascending inside each one.
def getInterval(fre_power_filepath:str):
    """Split a freq/power CSV into consecutive ascending-frequency intervals.

    Args:
        fre_power_filepath: Path to a CSV file with 'freq' and 'power' columns.

    Returns:
        (freq_list, power_list): two parallel lists with one entry per
        interval; each entry is a plain Python list of that interval's
        'freq' (respectively 'power') values in file order.
    """
    table = pd.read_csv(fre_power_filepath)
    # A row whose freq is lower than the previous row's freq opens a new
    # interval; the running count of such drops is the interval id.
    table['group'] = table['freq'].shift(1).gt(table['freq']).cumsum()
    pieces = [chunk for _, chunk in table.groupby('group')]
    freq_list = [chunk['freq'].tolist() for chunk in pieces]
    power_list = [chunk['power'].tolist() for chunk in pieces]
    return freq_list, power_list

In [None]:
# Test the shift() and groupby() methods
# Sanity check of the grouping expression used in getInterval: tag every row of
# the first freq/power CSV with a running interval id and display the frame.
# NOTE(review): entropy_file1, entropy_file2 and fp_file2 are assigned here but
# not used in this cell.
entropy_file1 = '../data/large-762M.test.model=gpt2.nll'
entropy_file2 = '../data/webtext.test.model=gpt2.nll'
fp_file1 = '../plot/large-762M.test.model=gpt2.freq_power.csv'
fp_file2 = '../plot/webtext.test.model=gpt2.freq_power.csv'

spectrum = pd.read_csv(fp_file1)
# The id increases by one each time 'freq' is lower than on the previous row.
spectrum['group'] = (spectrum['freq'].shift(1) > spectrum['freq']).cumsum()
spectrum

In [3]:
# Return a function fitted through the scattered points (linear by default;
# quadratic or cubic on request).
def getF(freq_list:list, power_list:list, kind:str = "linear"):
    """Build an interpolating function power = f(freq) from sample points.

    Args:
        freq_list: x coordinates (frequencies) of the samples.
        power_list: y coordinates (powers), same length as freq_list.
        kind: Interpolation kind forwarded to scipy's interp1d, e.g.
            "linear" (default, the previous fixed behaviour), "quadratic"
            or "cubic".

    Returns:
        A callable that evaluates the interpolant at scalar or array x.
        Points outside the sampled range are extrapolated rather than
        rejected.
    """
    f = interpolate.interp1d(freq_list, power_list, kind=kind, fill_value="extrapolate")
    return f

In [13]:
# From the contents of the two files, return for every interval the y values at
# a fixed, shared grid of x positions covering [0, 0.5].
# Open question from the original author: should points outside the sampled
# range be computed at all, or simply truncated (e.g. using 0.48 to predict 0.5)?
def alignPoints(filepath1:str, filepath2:str, n_common:int = 200, sort=False, verbose=False):
    """Resample the intervals of two spectrum files onto one common x grid.

    Intervals are read with getInterval, paired by position (first with
    first, and so on), turned into interpolants with getF, and evaluated on
    np.linspace(0, 0.5, n_common).

    Args:
        filepath1: Path of the first freq/power CSV.
        filepath2: Path of the second freq/power CSV.
        n_common: Number of grid points in [0, 0.5] (both ends included).
        sort: If True, order each file's intervals by their number of points
            before pairing them.
        verbose: If True, print how many intervals each file contains.

    Returns:
        (x, y1listlist, y2listlist): the shared grid and, per paired
        interval, the interpolated values for file 1 and file 2. Both lists
        always have the same length.

    Notes:
        When the files hold different numbers of intervals, only the leading
        intervals present in both are paired and a warning is issued.
        Previously a shorter second file raised IndexError.
    """
    import warnings

    freq_list_list_1, power_list_list_1 = getInterval(filepath1)
    freq_list_list_2, power_list_list_2 = getInterval(filepath2)

    # Keep each interval's freq and power together so they cannot get out of
    # step when reordered.
    intervals_1 = list(zip(freq_list_list_1, power_list_list_1))
    intervals_2 = list(zip(freq_list_list_2, power_list_list_2))

    # sort the intervals based on the length of the list (stable sort)
    if sort:
        intervals_1.sort(key=lambda pair: len(pair[0]))
        intervals_2.sort(key=lambda pair: len(pair[0]))
    if verbose:
        print(f'There are {len(freq_list_list_1)},{len(freq_list_list_2)} intervals in file {filepath1},{filepath2} respectively')
    if len(intervals_1) != len(intervals_2):
        warnings.warn(
            f'interval count mismatch ({len(intervals_1)} vs {len(intervals_2)}); '
            f'only the first {min(len(intervals_1), len(intervals_2))} pairs are aligned'
        )

    x = np.linspace(0, 0.5, n_common)
    y1listlist, y2listlist = [], []
    # zip stops at the shorter file, so every y1 has a matching y2.
    for (freq_list1, power_list1), (freq_list2, power_list2) in zip(intervals_1, intervals_2):
        func1 = getF(freq_list1, power_list1)
        func2 = getF(freq_list2, power_list2)
        # interpolate
        y1listlist.append(func1(x))
        y2listlist.append(func2(x))

    return x, y1listlist, y2listlist

In [5]:
# Compute the area under the curve for every frequency interval.
def getPSO(filepath1:str, filepath2:str):
    """Per-interval overlap ratio between the spectra of two files.

    For each interval pair returned by alignPoints, the pointwise minimum
    and pointwise maximum of the two interpolated curves are integrated
    over the common x grid with the trapezoidal rule.

    Args:
        filepath1: Path of the first freq/power CSV.
        filepath2: Path of the second freq/power CSV.

    Returns:
        (area_floor_list, area_roof_list, pso_list): per interval, the area
        under the pointwise minimum, the area under the pointwise maximum,
        and their ratio floor / roof rounded to 4 decimals. NaNs in the
        interpolated curves propagate into the results.
    """
    # NumPy 2.0 renamed trapz to trapezoid and deprecated the old name;
    # use the new one when this NumPy has it.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    area_floor_list, area_roof_list, pso_list = [], [], []
    xlist, y1listlist, y2listlist = alignPoints(filepath1, filepath2)

    for y1list, y2list in zip(y1listlist, y2listlist):
        y_intersection = np.minimum(y1list, y2list)
        y_roof = np.maximum(y1list, y2list)
        area_floor = integrate(y_intersection, xlist)
        area_roof = integrate(y_roof, xlist)

        area_floor_list.append(area_floor)
        area_roof_list.append(area_roof)
        pso_list.append(round(area_floor / area_roof, 4))

    return area_floor_list, area_roof_list, pso_list

In [28]:
# Compute the Pearson correlation for every frequency interval.
def getPearson(filepath1:str, filepath2:str):
    """Per-interval Pearson correlation between the spectra of two files.

    Each interval pair from alignPoints is restricted to the grid positions
    where both curves are finite before the coefficient is computed.

    Args:
        filepath1: Path of the first freq/power CSV.
        filepath2: Path of the second freq/power CSV.

    Returns:
        A list with one correlation coefficient per interval pair.

    Raises:
        ValueError: re-raised from pearsonr after printing the original and
            filtered sizes of the offending pair.
    """
    _, spectra_a, spectra_b = alignPoints(filepath1, filepath2)
    corr_list = []
    for idx, raw_a in enumerate(spectra_a):
        raw_b = spectra_b[idx]
        a = np.array(raw_a)
        b = np.array(raw_b)
        # Keep only positions where neither curve is NaN or infinite.
        keep = np.isfinite(a) & np.isfinite(b)
        a, b = a[keep], b[keep]
        try:
            coefficient, _pvalue = pearsonr(a, b)
        except ValueError:
            # Show sizes before and after filtering to help diagnose the failure.
            print(len(raw_a), len(raw_b))
            print(a.shape, b.shape)
            raise
        corr_list.append(coefficient)
    return corr_list

In [31]:
# Calculate the similarity between two spectra using Spectral Angle Mapper
def getSAM(filepath1:str, filepath2:str):
    """Per-interval spectral angle between the spectra of two files.

    Each interval pair from alignPoints is scaled to unit Euclidean norm;
    the angle between the two unit vectors is divided by pi, so 0 means
    same direction and 1 means opposite direction.

    Args:
        filepath1: Path of the first freq/power CSV.
        filepath2: Path of the second freq/power CSV.

    Returns:
        A list with one normalised angle per interval pair. An interval
        whose curve has zero norm or contains NaN yields NaN.
    """
    xlist, y1listlist, y2listlist = alignPoints(filepath1, filepath2)
    sam_list = []
    for y1list, y2list in zip(y1listlist, y2listlist):
        # Normalize the spectra (into new arrays; the inputs are not modified)
        unit1 = np.asarray(y1list, dtype=float) / np.linalg.norm(y1list)
        unit2 = np.asarray(y2list, dtype=float) / np.linalg.norm(y2list)
        # Calculate the dot product. Rounding can push it marginally outside
        # [-1, 1], which would make arccos return NaN for (anti)parallel
        # spectra, so clamp it to the valid domain.
        dot_product = np.clip(np.dot(unit1, unit2), -1.0, 1.0)
        # Calculate the SAM similarity
        sam_similarity = np.arccos(dot_product) / np.pi
        sam_list.append(sam_similarity)

    return sam_list

# Tests

In [14]:
# Smoke test on the two freq/power CSVs: align them on the common grid, then
# compute the per-interval areas and PSO values.
# Note: getPSO calls alignPoints itself, so the files are read and interpolated
# twice in this cell.
fp_file1 = '../plot/large-762M.test.model=gpt2.freq_power.csv'
fp_file2 = '../plot/webtext.test.model=gpt2.freq_power.csv'
x, y1listlist, y2listlist = alignPoints(fp_file1, fp_file2)

area_floor_list, area_roof_list, pso_list = getPSO(fp_file1, fp_file2)
# All three lists should have one entry per paired interval.
print(len(area_floor_list), len(area_roof_list), len(pso_list))

  slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]
  y_new = slope*(x_new - x_lo)[:, None] + y_lo
  slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]
  y_new = slope*(x_new - x_lo)[:, None] + y_lo


4999 4999 4999


In [30]:
# Per-interval Pearson correlation for the same two files; print how many
# intervals were scored.
corr_list = getPearson(fp_file1, fp_file2)
print(len(corr_list))

  slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]
  y_new = slope*(x_new - x_lo)[:, None] + y_lo


4999


In [33]:
# Per-interval spectral-angle value for the same two files; print how many
# intervals were scored.
sam_list = getSAM(fp_file1, fp_file2)
print(len(sam_list))

  slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]
  y_new = slope*(x_new - x_lo)[:, None] + y_lo


4999


In [36]:
# Summarise the PSO values: drop NaN entries, then print mean and standard
# deviation (np.std with its default ddof=0).
pso_arr = np.array(pso_list)
pso_arr = pso_arr[~np.isnan(pso_arr)]
print(np.mean(pso_arr))
print(np.std(pso_arr))

0.3869626450580232
0.04771419768822357


In [37]:
# Summarise the Pearson correlations: drop NaN entries, then print mean and
# standard deviation (np.std with its default ddof=0).
corr_arr = np.array(corr_list)
corr_arr = corr_arr[~np.isnan(corr_arr)]
print(np.mean(corr_arr))
print(np.std(corr_arr))

0.040044331175459094
0.10829241054950708


In [38]:
# Summarise the spectral-angle values: drop NaN entries, then print mean and
# standard deviation (np.std with its default ddof=0).
sam_arr = np.array(sam_list)
sam_arr = sam_arr[~np.isnan(sam_arr)]
print(np.mean(sam_arr))
print(np.std(sam_arr))

0.30358397329042336
0.028562383954158607


In [40]:
# Write pso_arr, corr_arr, sam_arr to .csv files
# Each NaN-filtered array becomes a single-column CSV in the current working
# directory, without index or header row. The name `df` is rebound for each
# array in turn.
import pandas as pd

df = pd.DataFrame(pso_arr)
df.to_csv('pso_arr.csv', index=False, header=False)

df = pd.DataFrame(corr_arr)
df.to_csv('corr_arr.csv', index=False, header=False)

df = pd.DataFrame(sam_arr)
df.to_csv('sam_arr.csv', index=False, header=False)

In [4]:
###
# Experiments with OPT and SO metric
###
# Setup for the experiments below: helper imports and input directories.
import os
import glob
# Project-local module; its getPSO is what the experiment cells call.
import SpectrumTools as st

# Directory holding the '*.fft.csv' results for the OPT model outputs.
fft_results_dir = "../data/experiments_data/opt-original/"
# Reference spectra per domain.
# NOTE(review): "gs" presumably stands for gold standard — confirm.
gs_news_dir = "../data/gs_james/gs_news/"
gs_story_dir = "../data/gs_james/gs_story/"
gs_wiki_dir = "../data/gs_james/gs_wiki/"

In [14]:
# Collect every '*.fft.csv' under fft_results_dir and split the paths by
# domain ('news' / 'story' / 'wiki') and model-size tag: '125m' goes to the
# opt_sm_* lists, '6.7b' to the opt_bg_* lists.
# NOTE(review): the substring tests run on the full path, so a directory name
# containing one of these tags would also match.
all_opt_files = glob.glob(os.path.join(fft_results_dir, '*.fft.csv'))
print(len(all_opt_files))

opt_sm_news = [f for f in all_opt_files if 'news' in f and '125m' in f]
print(len(opt_sm_news))
opt_sm_story = [f for f in all_opt_files if 'story' in f and '125m' in f]
print(len(opt_sm_story))
opt_sm_wiki = [f for f in all_opt_files if 'wiki' in f and '125m' in f]
print(len(opt_sm_wiki))

opt_bg_news = [f for f in all_opt_files if 'news' in f and '6.7b' in f]
print(len(opt_bg_news))
opt_bg_story = [f for f in all_opt_files if 'story' in f and '6.7b' in f]
print(len(opt_bg_story))
opt_bg_wiki = [f for f in all_opt_files if 'wiki' in f and '6.7b' in f]
print(len(opt_bg_wiki))
# Show the file names of one group to check the naming pattern.
print(sorted([os.path.basename(f) for f in opt_bg_wiki]))

30
5
5
5
5
5
5
['webtext.train_opt_6.7b_top_50_wiki.sorted.split.0.fft.csv', 'webtext.train_opt_6.7b_top_50_wiki.sorted.split.200.fft.csv', 'webtext.train_opt_6.7b_top_50_wiki.sorted.split.400.fft.csv', 'webtext.train_opt_6.7b_top_50_wiki.sorted.split.600.fft.csv', 'webtext.train_opt_6.7b_top_50_wiki.sorted.split.800.fft.csv']


In [15]:
# Collect the reference CSVs of each domain and print how many were found.
gs_news = glob.glob(os.path.join(gs_news_dir, '*.csv'))
print(len(gs_news))

gs_story = glob.glob(os.path.join(gs_story_dir, '*.csv'))
print(len(gs_story))

gs_wiki = glob.glob(os.path.join(gs_wiki_dir, '*.csv'))
print(len(gs_wiki))
# Show the file names of one group to check the naming pattern.
print(sorted([os.path.basename(f) for f in gs_wiki]))

5
5
5
['webtext.train.model=.wiki_0.fft.csv', 'webtext.train.model=.wiki_1.fft.csv', 'webtext.train.model=.wiki_2.fft.csv', 'webtext.train.model=.wiki_3.fft.csv', 'webtext.train.model=.wiki_4.fft.csv']


In [34]:
# Preview one sorted OPT list next to its reference list, then sort every list.
# The later cells pair files by position (index i with index i), so both sides
# must be in a consistent order. sorted() orders the path strings
# lexicographically.
print(sorted(opt_sm_news))
print(sorted(gs_news))

opt_sm_news = sorted(opt_sm_news)
opt_sm_story = sorted(opt_sm_story)
opt_sm_wiki = sorted(opt_sm_wiki)

opt_bg_news = sorted(opt_bg_news)
opt_bg_story = sorted(opt_bg_story)
opt_bg_wiki = sorted(opt_bg_wiki)

gs_news = sorted(gs_news)
gs_story = sorted(gs_story)
gs_wiki = sorted(gs_wiki)


['../data/experiments_data/opt-original/webtext.train_opt_125m_top_50_news.sorted.split.0.fft.csv', '../data/experiments_data/opt-original/webtext.train_opt_125m_top_50_news.sorted.split.200.fft.csv', '../data/experiments_data/opt-original/webtext.train_opt_125m_top_50_news.sorted.split.400.fft.csv', '../data/experiments_data/opt-original/webtext.train_opt_125m_top_50_news.sorted.split.600.fft.csv', '../data/experiments_data/opt-original/webtext.train_opt_125m_top_50_news.sorted.split.800.fft.csv']
['../data/gs_james/gs_news/webtext.train.model=.news_0.fft.csv', '../data/gs_james/gs_news/webtext.train.model=.news_1.fft.csv', '../data/gs_james/gs_news/webtext.train.model=.news_2.fft.csv', '../data/gs_james/gs_news/webtext.train.model=.news_3.fft.csv', '../data/gs_james/gs_news/webtext.train.model=.news_4.fft.csv']


In [35]:
# length_split_strs1 = ['0', '200', '400', '600', '800']
# length_split_strs2 = ['0', '1', '2', '3', '4']

# sm news
# Pair the i-th 125m news file with the i-th news reference file and keep the
# third value returned by st.getPSO for each pair.
# NOTE(review): presumably st.getPSO matches getPSO defined in this notebook
# (third value = per-interval PSO list) — confirm in SpectrumTools.
so_sm_news_list = []
for i in range(len(opt_sm_news)):
    opt_file = opt_sm_news[i]
    gs_file = gs_news[i]
    _, _, so_sm_news = st.getPSO(opt_file, gs_file)
    so_sm_news_list.append(so_sm_news)

# Number of file pairs, then the number of values obtained from each pair.
print(len(so_sm_news_list))
for li in so_sm_news_list:
    print(len(li))

5
844
1220
1194
764
978


In [36]:
# sm story
# Pair the i-th 125m story file with the i-th story reference file and keep
# the third value returned by st.getPSO for each pair.
so_sm_story_list = []
for i in range(len(opt_sm_story)):
    opt_file = opt_sm_story[i]
    gs_file = gs_story[i]
    _, _, so_sm_story = st.getPSO(opt_file, gs_file)
    so_sm_story_list.append(so_sm_story)

In [37]:
# sm wiki
# Pair the i-th 125m wiki file with the i-th wiki reference file and keep the
# third value returned by st.getPSO for each pair.
so_sm_wiki_list = []
for i in range(len(opt_sm_wiki)):
    opt_file = opt_sm_wiki[i]
    gs_file = gs_wiki[i]
    _, _, so_sm_wiki = st.getPSO(opt_file, gs_file)
    so_sm_wiki_list.append(so_sm_wiki)

In [38]:
# bg news
# Pair the i-th 6.7b news file with the i-th news reference file and keep the
# third value returned by st.getPSO for each pair.
so_bg_news_list = []
for i in range(len(opt_bg_news)):
    opt_file = opt_bg_news[i]
    gs_file = gs_news[i]
    _, _, so_bg_news = st.getPSO(opt_file, gs_file)
    so_bg_news_list.append(so_bg_news)

In [39]:
# bg story
# Pair the i-th 6.7b story file with the i-th story reference file and keep
# the third value returned by st.getPSO for each pair.
so_bg_story_list = []
for i in range(len(opt_bg_story)):
    opt_file = opt_bg_story[i]
    gs_file = gs_story[i]
    _, _, so_bg_story = st.getPSO(opt_file, gs_file)
    so_bg_story_list.append(so_bg_story)

In [40]:
# bg wiki
# Pair the i-th 6.7b wiki file with the i-th wiki reference file and keep the
# third value returned by st.getPSO for each pair.
so_bg_wiki_list = []
for i in range(len(opt_bg_wiki)):
    opt_file = opt_bg_wiki[i]
    gs_file = gs_wiki[i]
    _, _, so_bg_wiki = st.getPSO(opt_file, gs_file)
    so_bg_wiki_list.append(so_bg_wiki)

In [41]:
# Flatten each list of per-file result lists into one flat list per
# (model size, domain) group and print the resulting lengths.
# Note: names such as so_sm_news were used as loop variables in the cells that
# built the *_list variables; they are rebound here to the flattened lists.
from itertools import chain

so_sm_news = list(chain.from_iterable(so_sm_news_list))
so_sm_story = list(chain.from_iterable(so_sm_story_list))
so_sm_wiki = list(chain.from_iterable(so_sm_wiki_list))
print(len(so_sm_news), len(so_sm_story), len(so_sm_wiki))

so_bg_news = list(chain.from_iterable(so_bg_news_list))
so_bg_story = list(chain.from_iterable(so_bg_story_list))
so_bg_wiki = list(chain.from_iterable(so_bg_wiki_list))
print(len(so_bg_news), len(so_bg_story), len(so_bg_wiki))

5000 5000 5000
5000 5000 5000


In [48]:
# Save all SO results
# One column per (model size, domain) group, written to OPT_SO.csv in the
# current working directory without the index. Building the frame from a dict
# of lists requires all six lists to have the same length.
import pandas as pd

df_so = pd.DataFrame({'so_sm_news': so_sm_news,
                      'so_sm_story': so_sm_story,
                      'so_sm_wiki': so_sm_wiki,
                      'so_bg_news': so_bg_news,
                      'so_bg_story': so_bg_story,
                      'so_bg_wiki': so_bg_wiki})
df_so.to_csv('OPT_SO.csv', index=False)

In [42]:
# Pool the three domains into one list per model size (125m vs 6.7b).
so_sm = so_sm_news + so_sm_story + so_sm_wiki
so_bg = so_bg_news + so_bg_story + so_bg_wiki
print(len(so_sm), len(so_bg))

15000 15000


In [43]:
import numpy as np
# t-test
from scipy import stats

# Two-sample t-test on the pooled values of the two model sizes;
# equal_var=False selects Welch's variant (no equal-variance assumption).
t, p = stats.ttest_ind(so_sm, so_bg, equal_var=False)
print(t, p)

# Group means, to read the direction of the difference.
print(np.mean(so_sm), np.mean(so_bg))

-8.839365344620955 1.0143137933061228e-18
0.4227081466666667 0.42691802


In [44]:
# Mean value per domain (news, story, wiki): first line 125m, second line 6.7b.
print(np.mean(so_sm_news), np.mean(so_sm_story), np.mean(so_sm_wiki))
print(np.mean(so_bg_news), np.mean(so_bg_story), np.mean(so_bg_wiki))

0.43809636 0.40555068 0.42447739999999995
0.44000948000000006 0.40467664 0.43606794


In [45]:
# Welch's t-test (equal_var=False), news domain only: 125m vs 6.7b.
t, p = stats.ttest_ind(so_sm_news, so_bg_news, equal_var=False)
print(t, p)

-3.7800744793525 0.00015769357583737378


In [46]:
# Welch's t-test (equal_var=False), story domain only: 125m vs 6.7b.
t, p = stats.ttest_ind(so_sm_story, so_bg_story, equal_var=False)
print(t, p)

0.8358859957808088 0.4032391919713635


In [47]:
# Welch's t-test (equal_var=False), wiki domain only: 125m vs 6.7b.
t, p = stats.ttest_ind(so_sm_wiki, so_bg_wiki, equal_var=False)
print(t, p)

-17.566796628191792 4.6588623166578475e-68
