<a href="https://colab.research.google.com/github/jmhuer/shift_invariant_dictionary_learning/blob/main/tcn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
from kwta import Sparsify1D_kactive 
from synthetic_data import create_synthetic_data
from torch import nn
from tcn import TemporalConvNet
import torch.nn.functional as F

import torch
import torch.optim as optim

import plotly.graph_objects as graph
def plot(all_history:list, title:str, log = False):
    """
    input:
        all_history: list of dicts to plot
    ret:
        None: show plotly fig
    """
    #symbol_sequence= ['circle-open', 'circle', 'circle-open-dot', 'square']
    fig = graph.Figure(layout = graph.Layout(title=graph.layout.Title(text=title))) 
    for i in range(len(all_history)):
        fig.add_trace(graph.Scatter(x = all_history[i]["x"], 
                                    y = all_history[i]["y"],
                                    name = all_history[i]["legend"],
                                    mode='markers',
                                    marker_size=5,
                                    marker_symbol=all_history[i]["marker_symbol"])) 
    if log: fig.update_xaxes(type="log")
    fig.show()

In [26]:
synth_data = create_synthetic_data(size = 5000)


beat patter dictionary 
[1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]
[1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
[0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0]
[0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2]
[0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0] 

[1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0, 1, 2, 0, 0]
[1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2, 1, 2, 0, 2]
[1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2]
[1, 2,

In [None]:

class autoencoder(nn.Module):
    def __init__(self, input_size, output_size, num_channels, kernel_size, dropout, wta_k):
        super(autoencoder, self).__init__()
        self.wta = Sparsify1D_kactive(k = wta_k)
        self.feature = TemporalConvNet(input_size, num_channels, kernel_size, dropout=dropout)
        self.encoder = torch.nn.Conv1d(in_channels=25, out_channels=10, kernel_size=40, padding=0, bias=False, stride=40)
        self.decoder = torch.nn.ConvTranspose1d(in_channels=10, out_channels=1, kernel_size=40, padding=0, bias=False, stride=40)
        self.encoder.weight.data.normal_(0, 0.1)
        self.decoder.weight.data.normal_(0, 0.1)
        self.code = None
    def get_kernels(self):
        return self.decoder.weight.data[:,0,:]
    # def feature_map(self, x):
    #     code = self.wta(self.encoder(x))
    #     return code
    def forward(self, x):
        # x needs to have dimension (N, C, L) in order to be passed into CNN
        output = self.feature(x)
        # print("~~~~~~~~feature size ", output.size())
        self.code = self.wta(self.encoder(output))
        # # print("~~~~~~~~code size ", code.size())
        output = self.decoder(self.code )
        return output


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device: ", device)

model = autoencoder(input_size=1, output_size=400, num_channels=[10,25], kernel_size=10, dropout=0.2, wta_k = 5).to(device)
print("first conv kernels size: ", model.get_kernels().cpu().size())

inputs = torch.tensor(synth_data[:,None,:]).float().to(device)
print("Input size: ", inputs.size())
out = model(inputs)
print("Output size: ", out.size(), "\n")


loss_fn = torch.nn.L1Loss().to(device)
optimizer = optim.Adam(model.parameters(), lr=.001) ##this has weight decay just like you implemented
epochs = 3000
history = {"loss": []}
for i in range(epochs):
  optimizer.zero_grad()
  output = model(inputs)

  #decaying WTA
  if i % 500 == 0 and i != 0:
      model.wta.k = max(1, model.wta.k - 1)
      print("model.wta.k: ", model.wta.k)

  loss = loss_fn(output, inputs)
  loss.backward()
  torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
  optimizer.step()
  history["loss"].append(float(loss))
  if i % 1 == 0:
      print("Epoch : {} \t Loss : {} \t ".format(i, round(float(loss),7)))
      # print("\nneg encoder ", float((model.encoder.weight.ravel() < 0).sum(dim=0)))






Using device:  cuda
first conv kernels size:  torch.Size([10, 40])
Input size:  torch.Size([5000, 1, 400])
Output size:  torch.Size([5000, 1, 400]) 

Epoch : 0 	 Loss : 0.9904435 	 
Epoch : 1 	 Loss : 0.9748788 	 
Epoch : 2 	 Loss : 0.964456 	 
Epoch : 3 	 Loss : 0.9583679 	 
Epoch : 4 	 Loss : 0.9561272 	 
Epoch : 5 	 Loss : 0.955913 	 
Epoch : 6 	 Loss : 0.9557047 	 
Epoch : 7 	 Loss : 0.9548104 	 
Epoch : 8 	 Loss : 0.9532573 	 
Epoch : 9 	 Loss : 0.9521345 	 
Epoch : 10 	 Loss : 0.9518078 	 
Epoch : 11 	 Loss : 0.951612 	 
Epoch : 12 	 Loss : 0.9513249 	 
Epoch : 13 	 Loss : 0.9510208 	 
Epoch : 14 	 Loss : 0.9506552 	 
Epoch : 15 	 Loss : 0.9502353 	 
Epoch : 16 	 Loss : 0.9497265 	 
Epoch : 17 	 Loss : 0.9492763 	 
Epoch : 18 	 Loss : 0.9487372 	 
Epoch : 19 	 Loss : 0.9477597 	 
Epoch : 20 	 Loss : 0.9464539 	 
Epoch : 21 	 Loss : 0.9446581 	 
Epoch : 22 	 Loss : 0.9417118 	 
Epoch : 23 	 Loss : 0.9372641 	 
Epoch : 24 	 Loss : 0.9308997 	 
Epoch : 25 	 Loss : 0.9221087 	 
Epoch

In [37]:

recunstruct = model(torch.tensor(inputs[1:2,:,:]).float())
# print(recunstruct.cpu().detach().numpy()[0,0,:])

#perfect plot
original_plot = {"legend": "original", 
                 "x": list(range(0,400)), 
                 "y": synth_data[0,:],
                 "marker_symbol": 'line-ne-open'}
                  
recunstruct_plot = {"legend": "reconstruct", 
             "x": list(range(0,400)), 
             "y": recunstruct.cpu().detach().numpy()[0,0,:],
             "marker_symbol": 'star'}      

plot([recunstruct_plot, original_plot], "Signal Comparison")


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [66]:
print("first conv kernels size: ", model.get_kernels().cpu().size())
kernel1 = model.get_kernels().cpu().numpy()[8].tolist()


# kernels6_plot  = {"legend": "original", 
#                  "x": list(range(0,40)), 
#                  "y": kernel6,
#                  "marker_symbol": 'triangle-up-open'}

#perfect plot
kernels1_plot  = {"legend": "original", 
                 "x": list(range(0,40)), 
                 "y": kernel1,
                 "marker_symbol": 'star'}

plot([kernels1_plot], "kernels_plot")

first conv kernels size:  torch.Size([10, 40])


In [67]:
#
kernel = 8

model(torch.tensor(inputs[0:1,:,:]))
feature = model.code.float().detach().cpu().numpy() #all 
print(feature)
feature = model.code.float().detach().cpu().numpy()[0][kernel] #only 3rd 
print(feature)
# feature plot


feature_plot  = {"legend": "feature", 
                 "x": list(range(0,40)), 
                 "y": feature,
                 "marker_symbol": 'star'}



plot([feature_plot], "feature_plot")


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



[[[ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    1.1330406e+01  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  4.2825422e+00  0.0000000e+00
    0.0000000e+00  4.3792367e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  4.3837686e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00]
  [-5.7565202e-03  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
    0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00  0.0000

In [69]:
kernel3 = (model.get_kernels().cpu().numpy()*8)[kernel].tolist()
print(kernel3)


section = 4

kernels3_plot  = {"legend": "kernels2_plot", 
                 "x": list(range(0,40)), 
                 "y": kernel3,
                 "marker_symbol": 'triangle-up-open'}

section2_plot  = {"legend": "section2_plot", 
                 "x": list(range(0,40)), 
                 "y": inputs[0:1,:,:].cpu().numpy().ravel()[section*40:(section+1)*40],
                 "marker_symbol": 'circle'}


print(inputs[0:1,:,:].cpu().numpy().ravel()[section*40:(section+1)*40])

plot([kernels3_plot,section2_plot], "kernel and features")

[0.0025168266147375107, -0.00463726744055748, 3.0443341732025146, 0.0006762187695130706, 0.0005618548020720482, -0.0023038694635033607, 3.0442776679992676, -0.00042039406253024936, -0.005215053912252188, 0.005596253089606762, 3.044457197189331, -0.0003199624188710004, -0.0012108583468943834, 0.003112421603873372, 3.0442235469818115, 0.0016133723547682166, -0.004639571998268366, 0.0018471474759280682, 3.0409607887268066, -0.001687465701252222, -0.003374671097844839, 0.002791352104395628, 3.0445451736450195, -0.0017593556549400091, 0.0038986632134765387, 0.003742198459804058, 3.044677734375, 0.0022296763490885496, 0.007303720340132713, 0.002730142790824175, 3.0441970825195312, -0.0024917293339967728, 0.005525171756744385, -0.0014391691656783223, 3.0445168018341064, -0.0016820632154121995, 0.0018678298220038414, -0.001747415866702795, 3.0440211296081543, -0.0036745688412338495]
[0. 0. 3. 0. 0. 0. 3. 0. 0. 0. 3. 0. 0. 0. 3. 0. 0. 0. 3. 0. 0. 0. 3. 0.
 0. 0. 3. 0. 0. 0. 3. 0. 0. 0. 3. 0. 0.