In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as fn

In [None]:
# Cosine similarity over the last (feature) axis; eps keeps the normalisation
# away from division by zero for near-zero vectors.
cos = torch.nn.CosineSimilarity(eps=1e-6, dim=-1)

In [None]:
def augment(x, max_shift=5, noise_scale=0.05):
    """Randomly shift a signal along its last axis and add one random offset.

    A shift ``s`` is drawn uniformly from ``[-max_shift, max_shift]`` (both ends
    inclusive). For ``s > 0`` the signal moves right and the vacated leading
    positions are filled with each row's first value. For ``s < 0`` it moves
    left and the vacated trailing positions are filled with each row's last
    value. A single offset drawn from ``[-noise_scale/2, noise_scale/2)`` is
    then added to every element.

    Args:
        x: numpy array. The shift is applied along the last axis; the notebook
            passes individual samples of shape (1, 1, L).
        max_shift: largest absolute shift, clamped to ``L - 1`` so the slice
            assignment below is always valid.
        noise_scale: width of the uniform offset interval.

    Returns:
        A new array with the same shape as ``x``.
    """
    length = x.shape[-1]
    # A shift of L or more would leave nothing of the signal; cap it.
    max_shift = max(0, min(int(max_shift), length - 1))
    # randint's upper bound is exclusive, so +1 makes the range symmetric.
    shift = np.random.randint(-max_shift, max_shift + 1)
    if shift >= 0:
        # Pad the front with each row's first sample, then shift right.
        shifted = np.zeros_like(x) + x[..., 0:1]
        shifted[..., shift:] = x[..., :length - shift]
    else:
        # Pad the back with each row's last sample, then shift left.
        shifted = np.zeros_like(x) + x[..., -1:]
        shifted[..., :shift] = x[..., -shift:]
    # One offset for the whole sample (broadcast), not per-element noise.
    offset = (np.random.random(size=(1,) * x.ndim) - 0.5) * noise_scale
    return shifted + offset

In [None]:
def get_2N(data):
    """Build two independently augmented views of every sample in ``data``.

    Each sample is passed through ``augment`` twice (first view, then second
    view, sample by sample). Each result is converted to a float32 tensor, and
    each view is stacked along a new leading batch axis.

    Returns:
        (view_a, view_b): two tensors of identical shape (len(data), *sample_shape).
    """
    pairs = [
        (
            torch.from_numpy(augment(sample)).float(),
            torch.from_numpy(augment(sample)).float(),
        )
        for sample in data
    ]
    view_a = torch.stack([first for first, _ in pairs], 0)
    view_b = torch.stack([second for _, second in pairs], 0)
    return view_a, view_b


In [None]:
def calculate_loss(o1, o2, debug=False, temperature=1.0):
    """Symmetric contrastive loss between two batches of augmented views.

    Row ``i`` of ``o1`` and row ``i`` of ``o2`` are two augmentations of the
    same sample (the positive pair). Every cross-view pair ``(o1[i], o2[j])``
    with ``i != j`` is a negative. With ``s_ij = cos(o1[i], o2[j]) / temperature``
    the per-sample loss is

        -log( 2*exp(s_ii) /
              (2*exp(s_ii) + sum_{j!=i} exp(s_ij) + sum_{j!=i} exp(s_ji)) )

    The returned value is the sum over the batch. The computation is carried
    out in log space with ``logsumexp``, so small temperatures do not overflow
    ``exp``.

    Args:
        o1, o2: embeddings of shape (N, D) for the first / second view,
            assumed to live on the same device.
        debug: unused; kept so existing call sites keep working.
        temperature: divisor applied to the cosine similarities. The default
            1.0 reproduces the original un-scaled loss.

    Returns:
        Scalar tensor (sum of per-sample losses) on the inputs' device.
    """
    import math

    n = o1.shape[0]

    # (N, N) scaled similarities: sim[i, j] = cos(o1[i], o2[j]) / temperature.
    sim = fn.cosine_similarity(
        o1.unsqueeze(1), o2.unsqueeze(0), dim=-1, eps=1e-6
    ) / temperature

    # Diagonal mask created on the inputs' device instead of a hard-coded .cuda().
    diag = torch.eye(n, dtype=torch.bool, device=sim.device)

    # Positive term log(2 * exp(s_ii)) = s_ii + log 2.
    pos = torch.diagonal(sim) + math.log(2.0)

    # Negatives for sample i: row i (s_ij) and column i (s_ji), diagonal removed.
    neg_row = sim.masked_fill(diag, float("-inf"))
    neg_col = sim.t().masked_fill(diag, float("-inf"))

    logits = torch.cat([pos.unsqueeze(1), neg_row, neg_col], dim=1)
    loss_all = torch.logsumexp(logits, dim=1) - pos
    return loss_all.sum()


In [None]:
# Get Data
# Synthetic 1-D dataset: each sample is one of six signal types built from a
# narrow Gaussian bump exp(-(a-c)^2) and a wide one exp(-(a-c)^2 / 10). Each
# bump is centred at a random offset c in [2.5, 7.5) and perturbed with
# uniform noise.
SEED = 0
NUM_SAMPLES = 10000
# Class mix. An earlier, more balanced mix was [0.2, 0.1, 0.1, 0.2, 0.2, 0.2].
CLASS_PROBS = [0.2, 0.1, 0.1, 0.1, 0.1, 0.4]
NOISE_AMPLITUDE = 0.05

# Seed numpy so Restart & Run All regenerates the same dataset. This also
# fixes any later numpy-driven randomness in the kernel.
np.random.seed(SEED)

all_b = []
all_l = []
len_data = 50
a = np.linspace(0, 10, num=len_data)

# `sig_type` rather than `type`, so the builtin is not shadowed in the kernel.
for sig_type in np.random.choice(len(CLASS_PROBS), NUM_SAMPLES, p=CLASS_PROBS):
    offset = (np.random.random() - 0.5) * 5 + 5
    all_l.append(sig_type)

    narrow = np.exp(-(a - offset)**2)
    wide = np.exp(-(a - offset)**2 / 10)
    # Tuple index == class label: +narrow, -narrow, +wide, -wide,
    # wide - narrow, narrow - wide.
    templates = (narrow, -narrow, wide, -wide, wide - narrow, narrow - wide)
    b = templates[sig_type] + (np.random.random(len_data) - 0.5) * NOISE_AMPLITUDE
    all_b.append(b)

data_org = np.array(all_b, dtype=np.float32)
# Two singleton axes -> (NUM_SAMPLES, 1, 1, len_data), i.e. one channel of height 1.
data_org = data_org.reshape(-1, 1, 1, len_data)
labels = np.array(all_l)

assert data_org.shape == (NUM_SAMPLES, 1, 1, len_data)
assert labels.shape == (NUM_SAMPLES,)

# First index of each class, used to plot one example per class.
unq, unq_index = np.unique(labels, return_index=True)

print("d_shape", data_org.shape)
print("l_shape", labels.shape)

In [None]:
# One figure per class, showing the first sample of that class
# (unq_index holds the first index of each label).
for i in unq_index:
    fig, ax = plt.subplots()
    ax.plot(data_org[i].squeeze())
    ax.set(title=f"Class {labels[i]} (sample {i})", xlabel="sample index", ylabel="value")
    plt.show()

In [None]:
# Model

class encoderc(torch.nn.Module):
    def __init__(self):
        super(encoderc, self).__init__()
        self.e1 = torch.nn.Conv2d(1, 8, kernel_size=(1,5), stride=(1,2))
        _ = torch.nn.init.xavier_uniform_(self.e1.weight, 1.5)
        self.e2 = torch.nn.Conv2d(8, 16, kernel_size=(1,5), stride=(1,2))
        _ = torch.nn.init.xavier_uniform_(self.e2.weight, 1.5)
        self.e3 = torch.nn.Conv2d(16, 32, kernel_size=(1,3), stride=(1,2))
        _ = torch.nn.init.xavier_uniform_(self.e3.weight, 1.5)
        self.e4 = torch.nn.Conv2d(32, 64, kernel_size=(1,3), stride=(1,2))
        _ = torch.nn.init.xavier_uniform_(self.e4.weight, 1.5)
        self.e5 = torch.nn.Linear(64, 32)
        _ = torch.nn.init.xavier_uniform_(self.e5.weight, 1.5)
        self.e6 = torch.nn.Linear(32, 2)
        _ = torch.nn.init.xavier_uniform_(self.e6.weight, 1.5)

    def forward(self, x):
        e = fn.relu(self.e1(x))
        e = fn.relu(self.e2(e))
        e = fn.relu(self.e3(e))
        e = fn.relu(self.e4(e))
        e = torch.flatten(e, start_dim=1)
        e = fn.relu(self.e5(e))
        e = self.e6(e)
        return e

In [None]:
# Model, optimiser and training hyper-parameters.
torch.manual_seed(0)  # reproducible weight initialisation
# Use the GPU when available instead of hard-coding .cuda(), so the cell runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = encoderc().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 256
batch_size = 512

In [None]:
# Training: each epoch visits the samples in a fresh random order, draws two
# augmented views per sample and minimises the contrastive loss between them.
# Every 10 epochs the 2-D embedding of the whole dataset is plotted by class.
device = next(model.parameters()).device  # follow the model instead of hard-coding CUDA
num_samples = len(data_org)

for epoch in range(num_epochs):
    # Shuffle through an index permutation rather than reordering
    # data_org/labels in place, so indices computed earlier (e.g. unq_index)
    # stay valid when cells are re-run.
    perm = np.random.permutation(num_samples)

    if epoch % 10 == 0:
        # Inference only: no_grad avoids building an autograd graph over the full dataset.
        with torch.no_grad():
            oan = model(torch.from_numpy(data_org).to(device)).cpu().numpy()
        fig, ax = plt.subplots()
        sc = ax.scatter(oan[:, 0], oan[:, 1], c=labels, cmap="rainbow", alpha=0.5)
        fig.colorbar(sc, ax=ax, label="class")
        ax.set(title=f"Embedding at epoch {epoch}", xlabel="dim 0", ylabel="dim 1")
        plt.show()

    epoch_loss = 0.0
    num_batches = 0
    for start in range(0, num_samples, batch_size):
        data = data_org[perm[start:start + batch_size]]

        # Get 2N Augmentations
        d1, d2 = get_2N(data)

        # Forward
        o1 = model(d1.to(device))
        o2 = model(d2.to(device))
        loss = calculate_loss(o1, o2)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean batch loss over the epoch, not just the last (smaller) batch.
    print('epoch [{}/{}], loss:{:.4f}'
      .format(epoch + 1, num_epochs, epoch_loss / max(num_batches, 1)))
