In [None]:
#|default_exp plot.utils

# Plot
> This module tries to include most of the plotting functionality available in the package

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#|export
from fastcore.basics import store_attr, patch_to, patch
from fastcore.xtras import globtastic
from fastcore.meta import delegates
from pathlib import Path
import os 
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from clean_plot.pickle import label
from clean_plot.functions import normalize
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
#|export
# Use seaborn's plain 'white' style for the figures drawn by this module.
sns.set_style(style='white')

In [None]:
#|export 
#|hide
import inspect

In [None]:
#|export 
class Plot:
    "Plot the normalized SSMs of the book stored under `path`"
    def __init__(self, path):
        self.path = Path(path)
        # Maps a method name to its normalized SSM; empty until something populates it
        self.norm = {}
        # Stem of `path` with underscores turned into spaces, used in plot titles
        self.book_name = self.path.stem.replace('_', ' ')

    @delegates(globtastic)
    def view_all_files(self, **kwargs):
        "Return what `globtastic` finds under `self.path`; `kwargs` are forwarded to it"
        return globtastic(self.path, **kwargs)

    def _save_heatmap(self, ssm, title, out_dir, ticks, labels):
        "Draw `ssm` as a heatmap on its own figure and save it as `out_dir/{title}.png`"
        fig, ax = plt.subplots()
        sns.heatmap(ssm, cmap='hot',
                    vmin=0, vmax=1, square=True,
                    xticklabels=False, ax=ax)
        ax.set_yticks(ticks)
        ax.set_yticklabels(labels, rotation=0)
        ax.set_ylabel('sentence number')
        fig.savefig(out_dir/f'{title}.png', dpi=300, bbox_inches='tight')
        print(f'Done plotting {title}.png')
        # Close the figure so a loop over many matrices does not pile up open
        # figures (and no empty figure is left behind for the notebook to show)
        plt.close(fig)

    def create_ssms(self):
        "Save one full-size heatmap per entry of `self.norm` into `self.path/'full_plots'`"
        new_path = self.path/'full_plots'
        new_path.mkdir(exist_ok=True)

        for method, norm_ssm in self.norm.items():
            # Five evenly spaced ticks from the first to the last row
            ticks = np.linspace(1, norm_ssm.shape[0], 5, dtype=int)
            self._save_heatmap(norm_ssm, f'{self.book_name} {method}',
                               new_path, ticks, ticks)

    def get_standardized(self, start, end):
        "Placeholder: not implemented yet"
        pass

    def get_corr_plots(self):
        "Placeholder: not implemented yet"
        pass

    def get_sectional_ssms(self, start, end, n_ticks=5):
        """Save a heatmap of rows/columns `start:end` for every entry of `self.norm`.

        `start=0, end=-1` selects the whole matrix; any other pair must satisfy
        `start < end` (checked with an `assert`, so an `AssertionError` is raised).
        `n_ticks` is the number of y-axis ticks. Images are written to
        `self.path/'sections_{start} {end}'`.
        """
        import gc
        whole = start == 0 and end == -1
        if not whole:
            assert start < end, 'Incorrect bounds'
        new_path = self.path/f'sections_{start} {end}'
        new_path.mkdir(exist_ok=True)

        for method, norm_ssm in self.norm.items():
            # Resolve the `-1` sentinel per matrix so the last row/column is kept
            stop = norm_ssm.shape[0] if whole else end
            # Tick labels are sentence numbers in the full text ...
            first_label = start + 1 if start == 0 else start
            labels = np.linspace(first_label, stop, n_ticks, dtype=int)
            # ... while tick positions are relative to the plotted section
            ticks = np.linspace(1, stop - start, n_ticks, dtype=int)

            self._save_heatmap(norm_ssm[start:stop, start:stop],
                               f'{self.book_name} {method}',
                               new_path, ticks, labels)
            _ = gc.collect()

    def __repr__(self):
        "Describe the object by the directory that contains `self.path`"
        # `str()` and `print()` fall back to `__repr__` because no `__str__` is defined
        dir_path = os.path.dirname(os.path.realpath(self.path))
        return f'This object contains the path to `{dir_path}`'

In [None]:
#|export
@patch
def get_normalized(self:Plot):
    "Load every `.npy` file under `self.path`, store its normalized matrix in `self.norm` and return `self.norm`"
    files = self.view_all_files(file_glob='*.npy')

    for f in files:
        f = Path(f)
        # Expected stem: `<book>_cleaned_<method>`; the part after the separator picks the label
        fname = f.stem.split('_cleaned_')
        if len(fname) < 2:
            # An unrelated `.npy` file: skip it instead of aborting the whole load with an IndexError
            print(f'Skipping {f.name}: no `_cleaned_` in the file name')
            continue
        method = label(fname[1])

        em = np.load(f)

        if fname[1] == 'lexical_wt_ssm':
            # This file is normalized as-is (no cosine similarity step);
            # presumably it already holds a similarity matrix -- TODO confirm
            print(em.shape)
            n = normalize(em)
            # modifies the input array inplace
            np.fill_diagonal(n, 1)
        else:
            n = normalize(cosine_similarity(em, em))

        self.norm[method] = n
        # Drop the local references before the next file is loaded
        del em, n
    return self.norm

In [None]:
#local
# Build a plotter for one book's directory.
# NOTE(review): absolute, machine-specific path -- this cell only runs where that directory exists.
plotter = Plot('/home/deven/embeddings/A_Modest_Proposal')

In [None]:
#local
# Same expression `Plot.__init__` uses to derive `book_name` from the path.
plotter.path.stem.replace('_', ' ')

'A Modest Proposal'

In [None]:
#local
# A bare expression at the end of a cell is displayed through `Plot.__repr__`.
plotter

This object contains the path to `/home/deven/embeddings`

In [None]:
#local
# `get_normalized` fills and returns `plotter.norm`, so `d` is that same dict.
d = plotter.get_normalized()

In [None]:
#local
# Print the shape of every normalized matrix as a quick sanity check.
for k, v in d.items():
    print(v.shape)

(299, 299)
(299, 299)
(299, 299)
(299, 299)
(299, 299)
(299, 299)
(299, 299)
(299, 299)


In [None]:
#local
# Write one heatmap PNG per entry of `plotter.norm` into `<path>/full_plots`.
plotter.create_ssms()

Done plotting A Modest Proposal DeCLUTR Small.png
Done plotting A Modest Proposal RoBERTa.png
Done plotting A Modest Proposal InferSent GloVe.png
Done plotting A Modest Proposal InferSent FastText.png
Done plotting A Modest Proposal DistilBERT.png
Done plotting A Modest Proposal MPNet.png
Done plotting A Modest Proposal USE.png
Done plotting A Modest Proposal DeCLUTR Base.png


<Figure size 432x288 with 0 Axes>

In [None]:
class Foo:
    "Toy class whose `repr()` and `str()` return different text"
    def __init__(self, i):
        self.i = i

    def __str__(self):
        # Text used by `print(obj)` and `str(obj)`
        return 'idk what this does'

    def __repr__(self):
        # Text used by `repr(obj)` and when the object is echoed by the notebook
        return 'Init value in i is {}'.format(self.i)
        

In [None]:
# Create an instance to compare its `repr()` and `str()` output.
x = Foo(5)

In [None]:
# Echoing the object as the cell's last expression shows `Foo.__repr__`.
x

Init value in i is 5

In [None]:
# `print` uses `Foo.__str__` instead.
print(x)

idk what this does
