In [1]:
import os
import pandas as pd
import librosa
import librosa.display
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import tqdm.notebook as tqdm
from torchsummary import summary
import torch.optim as optim
os.listdir('data/freesound-audio-tagging')

['audio_test',
 'audio_train',
 'sample_submission.csv',
 'test_post_competition.csv',
 'train.csv',
 'train_post_competition.csv']

In [2]:
len(os.listdir('data/freesound-audio-tagging/audio_train'))
df = pd.read_csv('data/freesound-audio-tagging/train.csv')
df.head()

Unnamed: 0,fname,label,manually_verified
0,00044347.wav,Hi-hat,0
1,001ca53d.wav,Saxophone,1
2,002d256b.wav,Trumpet,0
3,0033e230.wav,Glockenspiel,1
4,00353774.wav,Cello,1


In [3]:
sr = 44100
input_length = int(sr)
batch_size = 32


def audio_norm(data):
    max_data = np.max(data)
    min_data = np.min(data)
    data = (data-min_data)/(max_data-min_data+1e-6)
    return data-0.5


def load_audio_file(file_path, input_length=input_length):
    data = librosa.core.load(file_path, sr=sr)[0] 
    
    if len(data)>input_length:
        max_offset = len(data)-input_length
        offset = np.random.randint(max_offset)
        data = data[offset:input_length+offset]
        
    else:
        T = np.zeros(input_length, dtype=float)
        T[:len(data)] = data
        data = T
        #max_offset = input_length - len(data)
        ##offset = np.random.randint(max_offset)
        #data = np.pad(data, (offset, input_length - len(data) - offset), "constant")
        
    data = audio_norm(data)
    return np.array([data])

In [4]:
librosa.display.waveplot(load_audio_file('data/freesound-audio-tagging/audio_train/001ca53d.wav')[0],
                         sr=sr, 
                         max_points=50000.0, 
                         x_axis='time', 
                         offset=0.0)

<matplotlib.collections.PolyCollection at 0x1296a495f48>

# Data Loading

In [5]:
labels = sorted(set(df.label))
label_to_indice = {l:i for i,l in enumerate(labels)}
indice_to_label = {i:l for i,l in enumerate(labels)}

In [6]:
class FreeSoundDataset(torch.utils.data.Dataset):
    def __init__(self, df_path, data_path, train=True, split=0.8):
        
        self.df = pd.read_csv(df_path)
        
        self.df = self.df[:int(len(self.df)*split)] if train else self.df[int(len(self.df)*split):]
        
        self.data_path = data_path
        self.sr = 44100
        self.input_length = int(sr)
        self.batch_size = 32
        
        self.labels = sorted(set(self.df.label))
        self.label_to_indice = {l:i for i,l in enumerate(self.labels)}
        self.indice_to_label = {i:l for i,l in enumerate(self.labels)}
    
    def __len__(self):
        return len(self.df)-2
    
    def __getitem__(self, idx):
        file_path = self.data_path + list(df[idx: idx+1].fname)[0]
        label_indice = label_to_indice[list(df[idx: idx+1].label)[0]]
        return load_audio_file(file_path), label_indice
        

In [7]:
def bandpass_filter(signal, low, high, order = 5):
    sos = butter(order, [low, high], analog = False, btype = 'band', output = 'sos')
    y = sosfilt(sos, signal)
    return y
    
def make_signal(raw_signal, nyq = sr/2):
    return_signal = np.zeros((8, self.input_length))
    return_signal[0] = raw_signal

    cut_offs = [i/nyq for i in [1, 256, 512, 1024, 2048, 4096, 8192, 11024]]
    for i in range(1, len(cut_offs), 1):
        return_signal[i] = bandpass_filter(raw_signal, cut_offs[i-1], cut_offs[i])
    return return_signal
        
    
def shuffletwo(x, y):
    rng_state = np.random.get_state()
    np.random.shuffle(x)
    np.random.set_state(rng_state)
    np.random.shuffle(y)

In [8]:
FreeSoundData = FreeSoundDataset('data/freesound-audio-tagging/train.csv',
                                 'data/freesound-audio-tagging/audio_train/')
FreeSoundDataTest = FreeSoundDataset('data/freesound-audio-tagging/train.csv',
                                     'data/freesound-audio-tagging/audio_train/',
                                     train=False)
FreeSoundDataLoader = DataLoader(FreeSoundData, batch_size=64, shuffle=True)
FreeSoundDataTestLoader = DataLoader(FreeSoundDataTest, batch_size=64, shuffle=32)

# Model

In [9]:
class FreeSound_Sense(torch.nn.Module):
    
    def __init__(self):
        super(FreeSound_Sense, self).__init__()
        self.conv1d_1_16_9 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=9, padding=True)
        self.conv1d_16_16_9 = nn.Conv1d(in_channels=16, out_channels=16, kernel_size=9, padding=True)
        self.conv1d_16_32_3 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=True)
        self.conv1d_32_32_3 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, padding=True)
        self.conv1d_32_256_3 = nn.Conv1d(in_channels=32, out_channels=256, kernel_size=3, padding=True)
        self.conv1d_256_256_3 = nn.Conv1d(in_channels=256, out_channels=256, kernel_size=3, padding=True)
        
        self.maxpool_16 = nn.MaxPool1d(16)
        self.maxpool_4 = nn.MaxPool1d(4)
        
        self.relu = nn.ReLU()
        self.sigm = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(0.1)
        
        self.fc_256_64 = nn.Linear(in_features=256, out_features=64)
        self.fc_64_1024 = nn.Linear(in_features=64, out_features=1024)
        self.fc_1024_42 = nn.Linear(in_features=1024, out_features=42)
        
        
    def forward(self, x):
        
        # First Block
        x = self.conv1d_1_16_9(x)
        x = self.relu(x)
        x = self.conv1d_16_16_9(x)
        x = self.relu(x)
        x = self.maxpool_16(x)
        x = self.dropout(x)
        
        # Second Block
        x = self.conv1d_16_32_3(x)
        x = self.relu(x)
        x = self.conv1d_32_32_3(x)
        x = self.relu(x)
        x = self.maxpool_4(x)
        x = self.dropout(x)
        
        # Third Block
        x = self.conv1d_32_32_3(x)
        x = self.relu(x)
        x = self.conv1d_32_32_3(x)
        x = self.relu(x)
        x = self.maxpool_4(x)
        x = self.dropout(x)
        
        # Fourth Block
        x = self.conv1d_32_256_3(x)
        x = self.relu(x)
        x = self.conv1d_256_256_3(x)
        x = self.relu(x)
        x = torch.mean(x, 2)
 
        # Final Layers
        x = torch.flatten(x, start_dim=1)
        x = self.fc_256_64(x)
        x = self.relu(x)
        x = self.fc_64_1024(x)
        x = self.relu(x)
        x = self.fc_1024_42(x)
        x = self.softmax(x)
        
        return x

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Model = FreeSound_Sense()
Model.float()
Model.to(device)
summary(Model, (1, 44100))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1            [-1, 16, 44094]             160
              ReLU-2            [-1, 16, 44094]               0
            Conv1d-3            [-1, 16, 44088]           2,320
              ReLU-4            [-1, 16, 44088]               0
         MaxPool1d-5             [-1, 16, 2755]               0
           Dropout-6             [-1, 16, 2755]               0
            Conv1d-7             [-1, 32, 2755]           1,568
              ReLU-8             [-1, 32, 2755]               0
            Conv1d-9             [-1, 32, 2755]           3,104
             ReLU-10             [-1, 32, 2755]               0
        MaxPool1d-11              [-1, 32, 688]               0
          Dropout-12              [-1, 32, 688]               0
           Conv1d-13              [-1, 32, 688]           3,104
             ReLU-14              [-1, 

In [11]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
#Model = FreeSound_Sense()
#Model.float()
#Model.to(device)

Model = FreeSound_Sense()
Model.load_state_dict(torch.load("FreeSound_1D_conv_global_pool_1013_epoch.stDict"))
Model.float()
Model.to(device)

FreeSound_Sense(
  (conv1d_1_16_9): Conv1d(1, 16, kernel_size=(9,), stride=(1,), padding=(True,))
  (conv1d_16_16_9): Conv1d(16, 16, kernel_size=(9,), stride=(1,), padding=(True,))
  (conv1d_16_32_3): Conv1d(16, 32, kernel_size=(3,), stride=(1,), padding=(True,))
  (conv1d_32_32_3): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(True,))
  (conv1d_32_256_3): Conv1d(32, 256, kernel_size=(3,), stride=(1,), padding=(True,))
  (conv1d_256_256_3): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(True,))
  (maxpool_16): MaxPool1d(kernel_size=16, stride=16, padding=0, dilation=1, ceil_mode=False)
  (maxpool_4): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (relu): ReLU()
  (sigm): Sigmoid()
  (softmax): Softmax(dim=1)
  (dropout): Dropout(p=0.1, inplace=False)
  (fc_256_64): Linear(in_features=256, out_features=64, bias=True)
  (fc_64_1024): Linear(in_features=64, out_features=1024, bias=True)
  (fc_1024_42): Linear(in_features=1024, out_features=

In [12]:
criterion = nn.CrossEntropyLoss()
#optimizer = optim.SGD(Model.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(Model.parameters(), lr=0.0005, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

In [13]:
epoch_progress_bar = tqdm.tqdm(range(0, 50))
for epoch in epoch_progress_bar:
    avg_epoch_loss = 0
    data_progress_bar = tqdm.tqdm(FreeSoundDataLoader)
    positives=0
    for data, targets in data_progress_bar:
        data = data.float().to(device)
        targets = targets.long().to(device)
        

        optimizer.zero_grad()
        outputs = Model(data)
        

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_val = loss.item()
        avg_epoch_loss+=loss_val
        data_progress_bar.set_description(desc="Loss: "+str(loss_val))
        
        outputs = np.argmax(outputs.detach().cpu().numpy(), axis=1)
        targets = targets.cpu().numpy()
        positives += np.sum(targets==outputs)
    
    print('Epoch Loss: ', str(avg_epoch_loss/len(FreeSoundDataLoader)))
    print('Train Acc ', str(positives*100/(len(FreeSoundDataLoader)*64)))
    
    # Validation
    data_test_progress_bar = tqdm.tqdm(FreeSoundDataTestLoader)
    positives=0
    for data, targets in data_test_progress_bar:
        data = data.float().to(device)
        targets = targets.numpy()
        outputs = Model(data)
        outputs = np.argmax(outputs.detach().cpu().numpy(), axis=1)
        positives += np.sum(targets==outputs)
        
    print('Valid Acc ', str(positives*100/(len(FreeSoundDataTestLoader)*64)))

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))

HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4170745681313908
Train Acc  35.819327731092436


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.104166666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.400642972032563
Train Acc  37.5


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.46875


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4154986854360887
Train Acc  36.01628151260504


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.416666666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4102261306858863
Train Acc  36.4889705882353


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  34.791666666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4056575378450025
Train Acc  37.05357142857143


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  36.666666666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.410636525194184
Train Acc  36.46271008403362


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.15625


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.402416183167145
Train Acc  37.329306722689076


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  37.8125


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4129187339494207
Train Acc  36.31827731092437


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  37.447916666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.399241351279892
Train Acc  37.86764705882353


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.364583333333336


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4039904590414354
Train Acc  37.21113445378151


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  36.770833333333336


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.405817332387972
Train Acc  37.07983193277311


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  34.6875


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4011304618931617
Train Acc  37.381827731092436


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  37.604166666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4082905364637615
Train Acc  36.76470588235294


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  34.895833333333336


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.397590270563334
Train Acc  37.85451680672269


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  36.770833333333336


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4099383093729743
Train Acc  36.47584033613445


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.052083333333336


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.406058702148309
Train Acc  36.96165966386555


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  36.666666666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.401787479384607
Train Acc  37.34243697478992


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  36.875


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.402174005989267
Train Acc  37.39495798319328


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  32.916666666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4099305717884993
Train Acc  36.528361344537814


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  36.666666666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4074488487564216
Train Acc  36.830357142857146


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  37.083333333333336


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.3978748421709075
Train Acc  37.86764705882353


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  37.291666666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.402312943915359
Train Acc  37.171743697478995


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  34.947916666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.405321762341411
Train Acc  37.00105042016807


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.104166666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.417162959315196
Train Acc  35.924369747899156


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.572916666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.411820650100708
Train Acc  36.42331932773109


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  37.5


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.403294146561823
Train Acc  37.10609243697479


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.729166666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4014707633427212
Train Acc  37.355567226890756


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  37.604166666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.3966744787552776
Train Acc  37.95955882352941


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  37.239583333333336


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4146818213102197
Train Acc  36.081932773109244


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  36.09375


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.40834135969146
Train Acc  36.73844537815126


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.208333333333336


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4249439860592368
Train Acc  35.07090336134454


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  34.84375


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.420858699734471
Train Acc  35.47794117647059


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.104166666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.404752797439319
Train Acc  37.00105042016807


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  36.40625


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.40894713121302
Train Acc  36.67279411764706


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  36.5625


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.40185265581147
Train Acc  37.408088235294116


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  37.1875


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.3988591422553824
Train Acc  37.73634453781513


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  36.666666666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.397170886272142
Train Acc  37.89390756302521


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  34.21875


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.404917813148819
Train Acc  37.05357142857143


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  37.34375


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.40277135271986
Train Acc  37.31617647058823


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.260416666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.399445671995147
Train Acc  37.696953781512605


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  37.135416666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.3993173827644156
Train Acc  37.644432773109244


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  34.270833333333336


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.403479371752058
Train Acc  37.171743697478995


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  38.59375


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.3994409657326066
Train Acc  37.565651260504204


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  36.614583333333336


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.404452293860812
Train Acc  37.18487394957983


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.572916666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4046517279969546
Train Acc  37.171743697478995


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  37.96875


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4016839235770604
Train Acc  37.460609243697476


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  34.166666666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.397762046140783
Train Acc  37.71008403361345


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  38.697916666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.392799681976062
Train Acc  38.28781512605042


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.885416666666664


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.4110128558984325
Train Acc  36.43644957983193


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  35.208333333333336


HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


Epoch Loss:  3.419107308908671
Train Acc  35.635504201680675


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


Valid Acc  36.770833333333336



In [14]:
torch.save(Model.state_dict(), "FreeSound_44100_1D_conv_global_pool_50_epoch.stDict")

In [52]:
i=0
for data in FreeSoundDataLoader:
    if i==2:
        O = Model(data[0].float().to(device))
        A = data[1]
        #print(data[1])
        break
    i+=1
O = O.detach().cpu().numpy()
K = np.argmax(O, axis=1)==A.numpy()
print(np.sum(K)/len(K), np.sum(K))
list(zip(A, K))

0.375 12


[(tensor(11), False),
 (tensor(1), False),
 (tensor(32), True),
 (tensor(39), True),
 (tensor(23), False),
 (tensor(26), True),
 (tensor(14), False),
 (tensor(37), False),
 (tensor(34), False),
 (tensor(6), False),
 (tensor(22), False),
 (tensor(36), True),
 (tensor(1), False),
 (tensor(38), False),
 (tensor(0), True),
 (tensor(35), False),
 (tensor(17), False),
 (tensor(10), False),
 (tensor(21), False),
 (tensor(12), True),
 (tensor(23), True),
 (tensor(1), False),
 (tensor(6), True),
 (tensor(5), False),
 (tensor(21), False),
 (tensor(1), False),
 (tensor(35), True),
 (tensor(34), False),
 (tensor(32), True),
 (tensor(28), False),
 (tensor(12), True),
 (tensor(12), True)]