In [53]:
from sklearn.linear_model import LinearRegression
import numpy as np
import os
import torch
import plotly.express as px
import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go

# x-axis values used for every metric series plotted/fitted below
# (presumably the training steps of the saved checkpoints -- TODO confirm).
time_step = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000, 13000, 14000, 15000, 20000, 25000, 30000, 35000, 40000, 45000, 50000, 55000, 60000, 65000, 70000, 80000, 90000, 100000, 110000, 120000, 130000, 140000, 143000]


# Nested mapping indexed below as dic[model_name][task][metric_name].
# NOTE(review): hard-coded absolute path -- this only works on the machine that owns it.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
dic = torch.load("/users/qyu10/lab/circuits-over-time/compiled_metric_dict.pt")

# Per-task series for pythia-1.4b.
# NOTE(review): despite the `logit_diff_` prefix these read the 'mrr' entry,
# not a logit-difference entry -- confirm which metric is intended.
logit_diff_ioi = dic['pythia-1.4b']['ioi']['mrr']
logit_diff_greater_than = dic['pythia-1.4b']['greater_than']['mrr']
logit_diff_sentiment_cont = dic['pythia-1.4b']['sentiment_cont']['mrr']
logit_diff_sentiment_class = dic['pythia-1.4b']['sentiment_class']['mrr']

def moving_average(input_tensor, window_size):
    """Return the sliding-window mean of a 1-D tensor.

    Element ``j`` of the result is the mean of
    ``input_tensor[j : j + window_size]``, so the result has
    ``len(input_tensor) - window_size + 1`` elements and is empty when the
    window is longer than the input.

    Args:
        input_tensor (torch.Tensor): 1-D tensor of values to average.
        window_size (int): Number of consecutive elements in each window.
            Must be at least 1.

    Returns:
        torch.Tensor: 1-D tensor of window means.

    Raises:
        ValueError: If ``window_size`` is smaller than 1. Without this check a
            negative window silently produced wrong numbers (the slices below
            wrap around) and zero produced an opaque shape-mismatch error.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be a positive integer, got {window_size}")

    # Prepend a zero so that cumsum_vec[k] is the sum of the first k elements;
    # each window sum is then a difference of two prefix sums.
    padded_input = torch.cat([torch.zeros(1, device=input_tensor.device), input_tensor])
    cumsum_vec = torch.cumsum(padded_input, dim=0)
    return (cumsum_vec[window_size:] - cumsum_vec[:-window_size]) / window_size


def get_start(input_tensor, START_THRESHOLD, END_THRESHOLD, window_size=3):
    """
    Locate where a series starts rising and where the rise levels off.

    The start is the first step-to-step difference that exceeds
    START_THRESHOLD. From there, the differences are smoothed with a rolling
    mean of ``window_size`` elements; the end is the position (in the
    differences tensor) of the last element of the first window whose mean
    falls below END_THRESHOLD.

    Args:
        input_tensor (torch.Tensor): 1-D series to scan.
        START_THRESHOLD (float): A difference above this marks the start.
        END_THRESHOLD (float): A rolling-mean difference below this marks the end.
        window_size (int): Number of differences per rolling-mean window.

    Returns:
        tuple: ``(start, end)``. ``start`` is None when no difference exceeds
        START_THRESHOLD. ``end`` falls back to ``len(input_tensor) - 1`` when
        there is no start or fewer than ``window_size`` differences remain
        after it, and to ``len(input_tensor)`` when no window mean drops below
        END_THRESHOLD.
        NOTE(review): these two fallbacks differ by one -- confirm which is
        intended.
    """
    total = len(input_tensor)
    deltas = input_tensor[1:] - input_tensor[:-1]

    # Guard 1: nothing ever rises above the start threshold.
    rising = torch.nonzero(deltas > START_THRESHOLD).view(-1)
    if len(rising) == 0:
        return None, total - 1

    first_index = rising[0].item()

    # Guard 2: not enough differences left to fill even one window.
    if first_index + window_size > len(deltas):
        return first_index, total - 1

    smoothed = moving_average(deltas[first_index:], window_size)
    flat = torch.nonzero(smoothed < END_THRESHOLD).view(-1)
    if len(flat) == 0:
        return first_index, total

    # Window j of `smoothed` ends at position first_index + j + window_size - 1
    # in `deltas`; shift by that amount to report an index into `deltas`.
    shift = first_index + window_size - 1
    return first_index, flat[0].item() + shift


def line_with_gradient(tensor, time_step, intercept, coefficient, x_start, x_end, renderer=None, width=1200, height=500, **kwargs):
    """Plot a series against ``time_step`` and overlay a straight line in red.

    The overlay is ``coefficient[0] * x + intercept`` evaluated at
    ``time_step[x_start:x_end]``.

    Args:
        tensor: Series to plot on the y-axis (torch tensor or anything
            ``np.asarray`` accepts); one value per entry of ``time_step``.
        time_step: Sliceable sequence of x-axis values.
        intercept: Intercept of the overlaid line.
        coefficient: Indexable whose element 0 is the slope of the overlaid line.
        x_start: Slice start into ``time_step`` for the overlay (None means
            from the beginning).
        x_end: Slice end (exclusive) into ``time_step`` for the overlay.
        renderer: Optional renderer forwarded to ``fig.show``.
        width (int): Figure width passed to ``update_layout``.
        height (int): Figure height passed to ``update_layout``.
        **kwargs: Forwarded unchanged to ``plotly.express.line``.

    Returns:
        None. The figure is displayed as a side effect.
    """
    # np.array() fails on torch tensors that require grad or live on a GPU,
    # so detach and move to CPU explicitly before converting.
    if isinstance(tensor, torch.Tensor):
        y_values = tensor.detach().cpu().numpy()
    else:
        y_values = np.asarray(tensor)

    # Base plot of the raw series
    fig = px.line(x=time_step, y=y_values, **kwargs)

    # Evaluate the line only over the requested window
    slope = coefficient[0]
    x_values = time_step[x_start:x_end]
    y_line = [slope * x + intercept for x in x_values]

    # Colour the overlay through `line.color`: `fillcolor` only styles a filled
    # area, so on a plain 'lines' trace it never made the line red.
    fig.add_trace(go.Scatter(x=x_values, y=y_line, mode='lines', name='Superimposed Line', line=dict(color='red')))

    # Fixed-size figure
    fig.update_layout(
        autosize=False,
        width=width,
        height=height
    )

    # Show the figure with the optional renderer
    fig.show(renderer=renderer)

# coef[model_name][task] = [fitted slope, start index, end index]
coef = {}
for i in ['pythia-70m', 'pythia-14m', 'pythia-410m', 'pythia-1.4b', 'pythia-160m', 'pythia-12b', 'pythia-31m', 'pythia-2.8b']:
    coef[i] = {}
    for k in ['ioi', 'greater_than']:
        mrr = dic[i][k]['mrr']
        # Start: first step-to-step increase above 0.001.
        # End: rolling-mean increase drops below 2% of the series maximum.
        # NOTE(review): get_start may return None for `s`; the slices below
        # then begin at index 0.
        s, e = get_start(mrr, 0.001, 0.02*torch.max(mrr))
        # Fresh regression model for each (model, task) pair
        model = LinearRegression()

        # Fit metric value against the raw (not log-scaled) time_step values
        # inside the detected [s, e) window
        model.fit(np.array(time_step[s:e]).reshape(-1, 1), mrr[s:e])

        coef[i][k] = [model.coef_[0], s, e]
        
        # Plot the full series with the fitted line overlaid on the window;
        # the x-axis is logarithmic, the fit above is linear in time_step.
        line_with_gradient(mrr, time_step, model.intercept_, model.coef_, s, e, title = f"{i}_{k}", log_x = True)

# Print the collected [slope, start, end] entries
from pprint import pprint
pprint(coef)




{'pythia-1.4b': {'greater_than': [0.00232414249330759, 2, 4],
                 'ioi': [0.00016007188524785925, 9, 17]},
 'pythia-12b': {'greater_than': [0.000733169727027416, 2, 4],
                'ioi': [0.0001331457897278392, 9, 17]},
 'pythia-14m': {'greater_than': [2.4826371121031143e-05, 8, 13],
                'ioi': [1.7101331613957883e-06, 10, 16]},
 'pythia-160m': {'greater_than': [0.0016624154523015018, 4, 6],
                 'ioi': [6.391320538149712e-05, 9, 17]},
 'pythia-2.8b': {'greater_than': [3.277392361115129e-05, 7, 14],
                 'ioi': [0.0001443111718168666, 9, 17]},
 'pythia-31m': {'greater_than': [2.230811091998752e-05, 8, 13],
                'ioi': [1.348131187260151e-05, 10, 15]},
 'pythia-410m': {'greater_than': [2.4355606980103793e-05, 8, 15],
                 'ioi': [0.00014516054425853733, 9, 17]},
 'pythia-70m': {'greater_than': [0.00012002605944871898, 4, 6],
                'ioi': [2.396724148143491e-05, 9, 16]}}
