In [1]:
# If necessary, install requirements from repository root
# !pip install -r ../requirements.txt
import torch as t
from torch import Tensor
import torch.nn.functional as F
import numpy as np
from pathlib import Path
import os
import sys
import plotly.express as px
import plotly.graph_objects as go
from functools import *
from typing import List, Tuple, Union, Optional, Callable
from fancy_einsum import einsum
import einops
from jaxtyping import Float, Int
from tqdm import tqdm
from transformer_lens import utils
import pandas as pd

def line(x, y=None, hover=None, xaxis='', yaxis='', **kwargs):
    if type(y)==t.Tensor:
        y = utils.to_numpy(y.flatten())
    if type(x)==t.Tensor:
        x=utils.to_numpy(x.flatten())
    fig = px.line(x, y=y, hover_name=hover, **kwargs)
    fig.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
    if x.ndim==1:
        fig.update_layout(showlegend=False)
    fig.show()


def scatter(x, y, title="", xaxis="", yaxis="", colorbar_title="", **kwargs):
    fig = px.scatter(x=utils.to_numpy(x.flatten()), y=utils.to_numpy(y.flatten()), title=title, labels={"color": colorbar_title}, **kwargs)
    fig.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
    if "xaxis_range" in kwargs:
        fig.update_xaxes(range=kwargs["xaxis_range"])
    if "yaxis_range" in kwargs:
        fig.update_yaxes(range=kwargs["yaxis_range"])
    fig.show()


def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==t.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==t.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def line_marker(x, **kwargs):
    lines([x], mode='lines+markers', **kwargs)


def animate_lines(lines_list, snapshot_index = None, snapshot='snapshot', hover=None, xaxis='x', yaxis='y', title='', **kwargs):
    if type(lines_list)==list:
        lines_list = t.stack(lines_list, axis=0)
    lines_list = utils.to_numpy(lines_list)
    if snapshot_index is None:
        snapshot_index = np.arange(lines_list.shape[0])
    if hover is not None:
        hover = [i for j in range(len(snapshot_index)) for i in hover]
    rows=[]
    for i in range(lines_list.shape[0]):
        for j in range(lines_list.shape[1]):
            rows.append([lines_list[i][j], snapshot_index[i], j])
    df = pd.DataFrame(rows, columns=[yaxis, snapshot, xaxis])
    px.line(df, x=xaxis, y=yaxis, title=title, animation_frame=snapshot, range_y=[lines_list.min(), lines_list.max()], hover_name=hover,**kwargs).show()

def imshow(tensor: t.Tensor, xaxis=None, yaxis=None, animation_name='Snapshot', vline_positions=[], vline_labels=[], hline_positions=[], hline_labels=[], animation_labels=[], **kwargs):
    tensor = t.squeeze(tensor)
    fig = px.imshow(utils.to_numpy(tensor), labels={'x': xaxis, 'y': yaxis, 'animation_frame': animation_name}, **kwargs)
    if animation_labels:
        for i, label in enumerate(animation_labels):
            fig.layout.sliders[0].steps[i]["label"] = label
    for x, text in zip(vline_positions, vline_labels):
        fig.add_vline(x=x-0.5, line_width=1, annotation_text=text, annotation_position="top left")
    for y, text in zip(hline_positions, hline_labels):
        fig.add_hline(y=y-0.5, line_width=1, annotation_text=text, annotation_position="top left")
    y_axis, x_axis = [s for i, s in enumerate(tensor.shape) if i != kwargs.get("animation_frame", None)]
    fig.update_yaxes(range=[y_axis-0.5, 0-0.5], autorange=False)
    fig.update_xaxes(range=[0-0.5, x_axis-0.5], autorange=False)
    fig.show()
# Set default colour scheme
imshow = partial(imshow, color_continuous_scale='Blues')
# Creates good defaults for showing divergent colour scales (ie with both 
# positive and negative values, where 0 is white)
imshow_div = partial(imshow, color_continuous_scale='RdBu', color_continuous_midpoint=0.0)

def imshow_fourier(tensor, title='', animation_name='snapshot', facet_labels=[], animation_labels=[], xlim=None, ylim=None, **kwargs):
    # Set nice defaults for plotting functions in the 2D fourier basis
    # tensor is assumed to already be in the Fourier Basis
    tensor = t.squeeze(tensor)
    fig=px.imshow(utils.to_numpy(tensor),
            x=fourier_basis_names, 
            y=fourier_basis_names, 
            labels={'x':'Horizontal Component', 
                    'y':'Vertical Component', 
                    'animation_frame':animation_name},
            title=title,
            color_continuous_midpoint=0., 
            color_continuous_scale='RdBu', 
            **kwargs)
    fig.update(data=[{'hovertemplate':"%{x}x * %{y}y<br>Value:%{z:.4f}"}])
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label
    if animation_labels:
        for i, label in enumerate(animation_labels):
            fig.layout.sliders[0].steps[i]["label"] = label
    if ylim is not None:
        fig.update_yaxes(range=ylim, autorange=False)
    if xlim is not None:
        fig.update_xaxes(range=xlim)
    fig.show()


def animate_multi_lines(lines_list, y_index=None, snapshot_index = None, snapshot='snapshot', hover=None, swap_y_animate=False, **kwargs):
    # Can plot an animation of lines with multiple lines on the plot.
    if type(lines_list)==list:
        lines_list = t.stack(lines_list, axis=0)
    lines_list = utils.to_numpy(lines_list)
    lines_list = lines_list.transpose(2, 0, 1)
    if swap_y_animate:
        lines_list = lines_list.transpose(1, 0, 2)
    if snapshot_index is None:
        snapshot_index = np.arange(lines_list.shape[0])
    if y_index is None:
        y_index = [str(i) for i in range(lines_list.shape[1])]
    if hover is not None:
        hover = [i for j in range(len(snapshot_index)) for i in hover]
    # print(lines_list.shape)
    rows=[]
    for i in range(lines_list.shape[0]):
        for j in range(lines_list.shape[2]):
            rows.append(list(lines_list[i, :, j])+[snapshot_index[i], j])
    df = pd.DataFrame(rows, columns=y_index+[snapshot, 'x'])
    px.line(df, x='x', y=y_index, animation_frame=snapshot, range_y=[lines_list.min(), lines_list.max()], hover_name=hover, **kwargs).show()


def animate_scatter(lines_list, snapshot_index = None, snapshot='snapshot', hover=None, yaxis='y', xaxis='x', color=None, color_name = 'color', **kwargs):
    # Can plot an animated scatter plot
    # lines_list has shape snapshot x 2 x line
    if type(lines_list)==list:
        lines_list = t.stack(lines_list, axis=0)
    lines_list = utils.to_numpy(lines_list)
    if snapshot_index is None:
        snapshot_index = np.arange(lines_list.shape[0])
    if hover is not None:
        hover = [i for j in range(len(snapshot_index)) for i in hover]
    if color is None:
        color = np.ones(lines_list.shape[-1])
    if type(color)==t.Tensor:
        color = utils.to_numpy(color)
    if len(color.shape)==1:
        color = einops.repeat(color, 'x -> snapshot x', snapshot=lines_list.shape[0])
    # print(lines_list.shape)
    rows=[]
    for i in range(lines_list.shape[0]):
        for j in range(lines_list.shape[2]):
            rows.append([lines_list[i, 0, j].item(), lines_list[i, 1, j].item(), snapshot_index[i], color[i, j]])
    # print([lines_list[:, 0].min(), lines_list[:, 0].max()])
    # print([lines_list[:, 1].min(), lines_list[:, 1].max()])
    df = pd.DataFrame(rows, columns=[xaxis, yaxis, snapshot, color_name])
    px.scatter(df, x=xaxis, y=yaxis, animation_frame=snapshot, range_x=[lines_list[:, 0].min(), lines_list[:, 0].max()], range_y=[lines_list[:, 1].min(), lines_list[:, 1].max()], hover_name=hover, color=color_name, **kwargs).show()


if t.backends.mps.is_available():
    device = t.device('mps')
elif t.cuda.is_available():
    device = t.device('cuda')
else:
    device = t.device('cpu')

In [2]:
N = 128 
def make_fourier_basis(N: int) -> Tuple[Tensor, List[str]]:
    '''
    Returns a pair `fourier_basis, fourier_basis_names`, where `fourier_basis` is
    a `(N, N)` tensor whose rows are Fourier components and `fourier_basis_names`
    is a list of length `N` containing the names of the Fourier components (e.g.
    `["const", "cos 1", "sin 1", ...]`).
    '''
    fourier_basis = t.ones(N, N)
    fourier_basis_names = ['Const']

    for i in range(1, N // 2 + 1):
        # Define each of the cos terms
        fourier_basis[2*i-1] = t.cos(2*t.pi*t.arange(N)*i/N)
        fourier_basis_names.append(f'cos {i}')

        # Define each of the sin terms, excluding the last one if p is even
        if 2*i < N:
            fourier_basis[2*i] = t.sin(2*t.pi*t.arange(N)*i/N)
            fourier_basis_names.append(f'sin {i}')

    # Normalize vectors, and return them
    fourier_basis /= fourier_basis.norm(dim=1, keepdim=True)
    return fourier_basis.to(device), fourier_basis_names


fourier_basis, fourier_basis_names = make_fourier_basis(N)

animate_lines(
    fourier_basis, 
    snapshot_index=fourier_basis_names, 
    snapshot='Fourier Component', 
    title=f'Fourier basis terms for N={N}'
)

In [3]:
x_density = 100
x = t.linspace(0, N, x_density*N)
k = N//2
cosx = t.cos(k * 2 * t.pi / N * x)
sinx = t.sin(k * 2 * t.pi / N * x)
n_lim = 10
x_lim = n_lim * x_density

lines(
    [cosx[:x_lim], sinx[:x_lim]],
    x=x[:x_lim],
    labels=['Final cosine term', 'Final sine term'],
    xaxis='x',
    yaxis='Amplitude', 
    title='Final cosine and sine terms for continuous variables'
)

cosn = t.cos(k * 2 * t.pi / N * t.arange(n_lim+1))
sinn = t.sin(k * 2 * t.pi / N * t.arange(n_lim+1))

lines(
    [cosn, sinn],
    labels=['Final cosine term', 'Final sine term'],
    xaxis='n',
    yaxis='Amplitude', 
    title='Final cosine and sine terms, sampled at discrete integer variables'
)

In [4]:
def fft1d(x: t.Tensor) -> t.Tensor:
    '''
    Returns the 1D Fourier transform of `x`,
    which can be a vector or a batch of vectors.

    x.shape = (..., p)
    '''
    basis, _ = make_fourier_basis(x.shape[-1])
    return basis.to(device) @ x.to(device)

In [5]:
x = 0.5 + 0.3 * t.cos(20 * 2 * t.pi / N * t.arange(N)) + 0.6 * + t.sin(13 * 2 * t.pi / N * t.arange(N))

lines(
    [x],
    labels=['Test sequence'],
    xaxis='n',
    yaxis='Amplitude', 
    title='Test sequence in the input domain'
)

line(
    fft1d(x).pow(2),
    hover=fourier_basis_names,
    xaxis='k',
    yaxis='Amplitude', 
    title='Test sequence in the frequency domain'
)


In [6]:
import torchaudio
import IPython.display as ipd

wav, sr = torchaudio.load('flute.wav')
ipd.Audio(wav.detach().numpy(), rate=sr)

In [7]:
# Keep from 150ms in to 200ms in, for the left channel only
x_min = int(sr*0.15) 
x_max = int(sr*0.2)
wav_sample = wav[0, x_min:x_max]


lines(
    [wav_sample],
    labels=['Test sequence'],
    xaxis='n',
    yaxis='Amplitude', 
    title='Input waveform'
)

ipd.Audio(wav_sample.detach().numpy(), rate=sr)


In [8]:
wav_N = wav_sample.shape[0]
wav_basis, wav_basis_names = make_fourier_basis(wav_N)
wav_spectrum = fft1d(wav_sample).pow(2)
max_k = 300
line(
    wav_spectrum[:max_k],
    hover=wav_basis_names[:max_k],
    xaxis='Fourier basis term',
    yaxis='Amplitude', 
    title='Fourier transform of the input waveform'
)

In [9]:
import math
def frequency_to_note(frequency):
    # Define the reference frequency for A4 (440 Hz)
    A4_frequency = 440
    # Calculate the number of semitones away from A4
    semitones = round(12 * math.log2(frequency / A4_frequency))
    # Define the list of note names
    note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
    # Calculate the octave and note index
    octave = semitones // 12 + 4
    note_index = semitones % 12
    # Return the note name and octave
    return note_names[note_index] + str(octave)

f1 = 25 * sr / wav_N
note1 = frequency_to_note(f1)
f2 = 32 * sr / wav_N
note2 = frequency_to_note(f2)

print(f'F1 = {f1} Hz ({note1})')
print(f'F2 = {f2} Hz ({note2})')

F1 = 500.0 Hz (D4)
F2 = 640.0 Hz (F#4)


In [10]:
def fourier_2d_basis_term(fourier_basis, k: int, l: int) -> Float[Tensor, "N N"]:
    '''
    Returns the 2D Fourier basis term corresponding to the outer product of the
    `k`-th component of the 1D Fourier basis in the `x` direction and the `l`-th
    component of the 1D Fourier basis in the `y` direction.

    Returns a 2D tensor of length `(N, N)`.
    '''
    fourier_basis = fourier_basis.to('cpu')
    return (fourier_basis[l][:, None] * fourier_basis[k][None, :])

def fft2d(tensor: t.Tensor) -> t.Tensor:
    '''
    Retuns the components of `tensor` in the 2D Fourier basis.

    Asumes that the input has shape `(N, N, ...)`, where the
    last dimensions (if present) are the batch dims.
    Output has the same shape as the input.
    '''
    # fourier_basis[k] is the k-th basis vector, which we want to multiply along
    N = tensor.shape[0]
    fourier_basis, fourier_basis_names = make_fourier_basis(N)
    return einops.einsum(
        tensor.cpu(), fourier_basis.cpu(), fourier_basis.cpu(), "pn pm ..., k pn, l pm -> k l ..."
    )

In [11]:
N = 64
fourier_basis, fourier_basis_names = make_fourier_basis(N)

k = 0
l = 1
imshow(fourier_2d_basis_term(fourier_basis, k, l), title=f"2-D Fourier basis term ({k}, {l})")

k = 1
l = 0
imshow(fourier_2d_basis_term(fourier_basis, k, l), title=f"2-D Fourier basis term ({k}, {l})")

k = 1
l = 1
imshow(fourier_2d_basis_term(fourier_basis, k, l), title=f"2-D Fourier basis term ({k}, {l})")

k = 7
l = 3
imshow(fourier_2d_basis_term(fourier_basis, k, l), title=f"2-D Fourier basis term ({k}, {l})")

In [12]:
N = 64
fourier_basis, fourier_basis_names = make_fourier_basis(N)
example_fn = sum([
    4* fourier_2d_basis_term(fourier_basis, 4, 6), 
    7* fourier_2d_basis_term(fourier_basis, 14, 46),
    8* fourier_2d_basis_term(fourier_basis, 30, 50)
])

imshow(example_fn.T, title=f"Example periodic function")

imshow_fourier(
    fft2d(example_fn),
    title='Example periodic function in 2D Fourier basis'
)