<a href="https://colab.research.google.com/github/jmhuer/shift_invariant_dictionary_learning/blob/main/tcn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
from kwta import Sparsify1D_kactive 
from synthetic_data import create_synthetic_data
from my_tests import autoencoder

In [5]:
from torch import nn
from tcn import TemporalConvNet
import torch.nn.functional as F

import torch
import torch.optim as optim
torch.manual_seed(42)

<torch._C.Generator at 0x7f54bea8a210>

In [6]:
import plotly.graph_objects as graph
def plot(all_history:list, title:str, log = False):
    """
    Draw every series in `all_history` as a marker-only scatter trace on one
    plotly figure and display it.

    input:
        all_history: list of dicts to plot. Each dict must provide "x" and
            "y" (point coordinates). "legend" (trace name) and
            "marker_symbol" (plotly marker symbol name) are optional; when
            absent, plotly's defaults are used.
        title: text used as the figure title
        log: when truthy, the x axis is switched to a logarithmic scale
    ret:
        None: show plotly fig
    """
    fig = graph.FigureWidget(layout = graph.Layout(title=graph.layout.Title(text=title)))
    for series in all_history:
        # Fixed marker size for all traces; the symbol is only set when the
        # caller supplied one, so a missing key no longer raises KeyError.
        marker = {"size": 5}
        if "marker_symbol" in series:
            marker["symbol"] = series["marker_symbol"]
        fig.add_trace(graph.Scatter(x = series["x"],
                                    y = series["y"],
                                    name = series.get("legend"),
                                    mode='markers',
                                    marker=marker))
    if log:
        fig.update_xaxes(type="log")
    fig.show()

In [7]:
# Build the synthetic dataset with size=5000. The call also prints the
# pattern lists captured in this cell's output.
# NOTE(review): the training cell's recorded output reports inputs of shape
# (5000, 1, 400) after a channel axis is inserted, so this presumably yields
# 5000 signals of 400 samples each — confirm in synthetic_data.py.
synth_data = create_synthetic_data(size=5000)


beat patter dictionary 
[1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]
[1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
[0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0]
[0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2]
[0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0] 

[1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0]
[1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2]
[1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2]
[1, 2,

In [None]:

# Train the autoencoder to reproduce its own input (L1 reconstruction loss),
# feeding the whole dataset as a single batch on every epoch.

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device: ", device)

model = autoencoder(input_size=1, output_size=400, num_channels=[1, 15, 25,1], kernel_size=10, dropout=0.2, wta_k = 5).to(device)
# Insert a singleton axis in position 1 so the batch is 3-D before it is
# handed to the model.
inputs = torch.tensor(synth_data[:,None,:]).float().to(device)
print("Input size: ", inputs.size())
# Shape sanity check only: gradients are not needed here, so skip building an
# autograd graph for the full batch.
with torch.no_grad():
    out = model(inputs)
print("Output size: ", out.size(), "\n")


loss_fn = torch.nn.L1Loss().to(device)
# SGD with a small L2 penalty (weight_decay) and light momentum.
optimizer = optim.SGD(model.parameters(), lr=.05, weight_decay = 0.00001, momentum=0.05)
epochs = 500
log_every = 100        # print the loss every this many epochs
wta_decay_every = 500  # lower model.wta.k by one every this many epochs
history = {"loss": []}
for i in range(epochs):
    optimizer.zero_grad()
    output = model(inputs)

    # Decaying WTA: shrink k by one (never below 1) on every
    # `wta_decay_every`-th epoch except epoch 0. The new k takes effect on the
    # next forward pass, because `output` above was already computed.
    # NOTE(review): with epochs == wta_decay_every this branch never runs
    # (i stops at epochs - 1); lower wta_decay_every or raise epochs if the
    # decay is meant to be active.
    if i % wta_decay_every == 0 and i != 0:
        model.wta.k = max(1, model.wta.k - 1)
        print("model.wta.k: ", model.wta.k)

    loss = loss_fn(output, inputs)
    loss.backward()
    optimizer.step()
    # Read the scalar loss once per epoch and reuse it for history and logging.
    loss_value = loss.item()
    history["loss"].append(loss_value)
    if i % log_every == 0:
        print("Epoch : {} \t Loss : {} \t ".format(i, round(loss_value,7)))

# NOTE(review): the output recorded for this cell shows the loss flat at
# ~0.95 for the first 300 epochs, i.e. the model does not appear to be
# learning with these settings — investigate before trusting later plots.






Using device:  cuda
Input size:  torch.Size([5000, 1, 400])
Output size:  torch.Size([5000, 1, 400]) 

Epoch : 0 	 Loss : 0.949999 	 
Epoch : 100 	 Loss : 0.9499988 	 
Epoch : 200 	 Loss : 0.9499983 	 
Epoch : 300 	 Loss : 0.9499986 	 


In [None]:

# Reconstruct one signal with the trained model and overlay it on the
# original. A single index drives both traces so the comparison is always
# between a signal and its own reconstruction.
sample_idx = 0

# Run inference without autograd and with the module in eval mode so that
# dropout (the model was built with dropout=0.2) does not perturb the plotted
# reconstruction. The previous train/eval flag is restored afterwards.
# NOTE(review): assumes `autoencoder` is a torch.nn.Module with standard
# train()/eval() semantics — confirm in my_tests.py.
was_training = model.training
model.eval()
with torch.no_grad():
    # `inputs` is already a float tensor on the model's device; slicing keeps
    # the leading batch axis, so no re-wrapping in torch.tensor is needed.
    recunstruct = model(inputs[sample_idx:sample_idx + 1,:,:])
model.train(was_training)

# x axis spans the signal length taken from the last axis of `inputs`.
signal_length = inputs.shape[-1]

#perfect plot
original_plot = {"legend": "original",
                 "x": list(range(signal_length)),
                 "y": synth_data[sample_idx,:],
                 "marker_symbol": 'line-ne-open'}

recunstruct_plot = {"legend": "reconstruct",
                    "x": list(range(signal_length)),
                    "y": recunstruct.cpu().detach().numpy()[0,0,:],
                    "marker_symbol": 'star'}

plot([recunstruct_plot, original_plot], "Signal Comparison")

In [None]:
# Fetch the kernels tensor once and reuse it for both the size report and the
# plot. detach() makes the later .numpy() call safe even if the tensor is
# attached to the autograd graph.
kernels = model.get_kernels().detach().cpu()
print("kernel size: ", kernels.size())
# Entry at index 1 along the first axis, converted to a plain Python list.
kernel1 = kernels.numpy()[1].tolist()

# NOTE(review): the x range is hard-coded to 40 points and the legend reads
# "original" although this trace is a kernel; both are kept as-is because the
# shape returned by get_kernels() is not visible here — confirm and adjust.
kernels1_plot  = {"legend": "original",
                 "x": list(range(0,40)),
                 "y": kernel1,
                 "marker_symbol": 'star'}

plot([kernels1_plot], "kernels_plot")