In [None]:
# Install the extra third-party packages into the notebook runtime.
# NOTE(review): versions are not pinned, so a fresh run may resolve different releases.
!pip3 install sktime
!pip3 install pytorch-lightning
!pip3 install imblearn

In [2]:
# Mount Google Drive at /content/drive so the dataset archive stored there is reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Copy the dataset archive from the mounted Drive into the current working directory.
!cp /content/drive/MyDrive/track_1.tar .

In [None]:
# Unpack the archive in place; -v prints every extracted file name.
# Presumably this creates the ./idao_dataset/ tree referenced by the paths in `df` -- confirm.
!tar -xvf track_1.tar

In [12]:
from utils import *
from dataset import *

In [13]:
# Build the file index with the star-imported project helper `files_to_df`.
# The displayed frame has the columns path, energy, particle and class_name;
# rows from the test folders carry -1 for energy and particle.
df = files_to_df()
df

Unnamed: 0,path,energy,particle,class_name
0,./idao_dataset/train/NR/-1.48580002784729__CYG...,1,0,NR 1
1,./idao_dataset/train/NR/0.6015999913215637__CY...,1,0,NR 1
2,./idao_dataset/train/NR/0.25130000710487366__C...,1,0,NR 1
3,./idao_dataset/private_test/263711b7666887916d...,-1,-1,ER -1
4,./idao_dataset/public_test/257b2d6250263096e68...,-1,-1,ER -1
...,...,...,...,...
29963,./idao_dataset/private_test/a19fe0430bb05c89ae...,-1,-1,ER -1
29964,./idao_dataset/train/ER/-1.2565394639968872__C...,3,1,ER 3
29965,./idao_dataset/train/ER/0.4510352313518524__CY...,10,1,ER 10
29966,./idao_dataset/private_test/a11ba4a2d8750a7e9e...,-1,-1,ER -1


In [17]:
# Number of support examples to take from every labelled class.
take_samples = 3
_cls = df['class_name'].unique()

# For each labelled class keep the row labels of its first `take_samples` rows.
# Class names containing '-1' belong to the unlabelled test rows and are skipped.
# Slicing already clamps to the class size, so no explicit min() is needed.
_picked = [
    df.loc[df['class_name'] == c].index.values[:take_samples]
    for c in _cls
    if '-1' not in c
]

# Concatenate once and keep the index dtype: starting from a bare np.array([])
# (float64) would silently turn the integer row labels into floats.
idx = np.concatenate(_picked) if _picked else np.array([], dtype=df.index.dtype)



In [159]:
# Boolean mask over df rows: True where the row label was selected into `idx`.
in_support = df.index.isin(idx)

# Support set = the selected rows; training pool = every other row that has a
# particle label (particle == -1 marks unlabelled rows).
support_df = df[in_support]
train_df = df[~in_support & (df.particle != -1)]

support_df['particle'].value_counts()

1    15
0    15
Name: particle, dtype: int64

In [224]:
# Stratified 50/50 split of the labelled pool on the particle label (fixed seed).
train_df, val_df = train_test_split(train_df, stratify=train_df.particle, test_size=0.5, random_state=0)

# Keep only the first 100 rows of each split to keep the experiment small.
train_df = train_df[:100]
# Slice the validation split itself: taking train_df[:100] here made validation
# identical to the training subset, so val_acc was measured on training data.
val_df = val_df[:100]
val_df.shape, train_df.shape

((100, 4), (100, 4))

In [226]:
# Count rows per particle label in the training subset to check class balance.
train_df.particle.value_counts()

0    51
1    49
Name: particle, dtype: int64

In [227]:
#Dataset
class ParticleDataset(Dataset):
  """Map-style dataset over a frame of image paths and particle labels.

  Item ``i`` is the pair ``(image, label)``: the image is loaded with ``read_im``
  from the row's ``path`` value and cast to float, the label is the row's
  ``particle`` value wrapped in a tensor.
  """

  def __init__(self, df, transforms=None):
    """Store the frame and the ``transforms`` object forwarded to ``read_im``."""
    self.transforms = transforms
    self.df = df

  def __len__(self):
    """Return the number of rows in the underlying frame."""
    return len(self.df)

  def __getitem__(self, _id):
    """Load the image and label for positional row ``_id``."""
    row = self.df.iloc[_id]
    target = row.particle
    image = read_im(row.path, self.transforms)
    return image.float(), torch.tensor(target)

In [228]:
# Wrap both splits as datasets; no transforms are passed, so the default (None) is used.
train = ParticleDataset(df=train_df)
val = ParticleDataset(df=val_df)

# Batches of 32 samples; no shuffle flag is given, so the loaders keep their default ordering.
train_loader = DataLoader(dataset=train, batch_size=32)
val_loader = DataLoader(dataset=val, batch_size=32)

In [229]:
#Model
class Embedder(nn.Module):
  """Four-stage convolutional encoder that returns a flat embedding per image.

  Every stage is Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2) with 64 output
  channels. The last feature map is flattened, so the embedding length depends
  on the spatial size of the input.
  """

  @staticmethod
  def set_zero_seed():
    """Seed Python's ``random``, torch and torch.cuda with 0.

    Declared static because it takes no ``self``: without the decorator,
    calling it on an instance raised ``TypeError``.
    """
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

  def conv2block(self, in_channels, out_channels=64, kernel_size=3):
    """Build one Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2) stage."""
    block = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2)
    )
    return block

  def __init__(self, in_channels=3):
    """Create the four stages; only the first one sees ``in_channels``."""
    super().__init__()
    self.convnet1 = self.conv2block(in_channels)
    self.convnet2 = self.conv2block(64)
    self.convnet3 = self.conv2block(64)
    self.convnet4 = self.conv2block(64)

  def forward(self, x):
    """Encode a batch ``x`` and return it reshaped to ``(batch, features)``."""
    x = self.convnet1(x)
    x = self.convnet2(x)
    x = self.convnet3(x)
    x = self.convnet4(x)
    # Flatten everything except the batch dimension.
    return x.reshape(x.shape[0], -1)


In [230]:
def set_zero_seed():
    """Reset Python's ``random``, torch and torch.cuda generators to seed 0."""
    for reseed in (random.seed, torch.manual_seed, torch.cuda.manual_seed):
        reseed(0)

In [231]:
class ProtoNet(pl.LightningModule):
  """Prototypical network: classify queries by distance to class prototypes.

  A fixed batch of support images is embedded together with every query batch.
  The support embeddings are averaged per class into prototypes and each query
  is assigned to the class of its nearest prototype (squared Euclidean
  distance). Training minimises the negative log-likelihood of the softmax over
  negative distances.
  """

  def __init__(self, support):
    """Create the encoder and store the support batch.

    Args:
      support: tensor of support images, grouped by class in the layout
        expected by ``get_prototypes`` (``k_way`` blocks of ``n_shot`` rows).
    """
    super(ProtoNet, self).__init__()

    print(self.device)
    set_zero_seed()
    self.encoder = Embedder()
    # Register as a non-persistent buffer so the support images move with the
    # module between devices while staying out of the state dict. A plain
    # attribute would be left behind on its original device.
    self.register_buffer('support', support, persistent=False)
    self.loss = nn.NLLLoss()

  def forward(self, query):
    """Return ``(predicted class per query, distances to every prototype)``."""
    support_embeddings = self.encoder(self.support)
    query_embeddings = self.encoder(query)
    prototypes = self.get_prototypes(support_embeddings)

    distances = self.pairwise_distances(query_embeddings, prototypes)

    # The nearest prototype is also the argmax of softmax(-distances), because
    # softmax is monotonic; argmin avoids the extra normalisation.
    y_pred = distances.argmin(dim=1)

    return y_pred, distances

  def training_step(self, batch, batch_idx):
    """Compute and log the NLL loss of one training batch."""
    imgs, labels = batch
    preds, distances = self(imgs)
    log_p_y = (-distances).log_softmax(dim=1)

    loss = self.loss(log_p_y, labels)
    self.log('train_loss', loss, prog_bar=True, logger=True)
    return loss

  def validation_step(self, batch, batch_idx):
    """Log the accuracy of one validation batch; return labels and predictions on CPU."""
    val_images, val_labels = batch
    preds, distances = self(val_images)
    # Compute accuracy with tensor ops so it works wherever the batch lives,
    # instead of handing tensors to a NumPy-based metric.
    acc = (preds == val_labels).float().mean()
    self.log('val_acc', acc, prog_bar=True, logger=True)
    # .float().cpu() replaces the legacy torch.Tensor(...) constructor.
    return val_labels.cpu(), preds.float().cpu()

  def configure_optimizers(self):
    """Use Adam with its default hyper-parameters over all parameters."""
    optimizer = torch.optim.Adam(self.parameters())
    return optimizer

  def get_prototypes(self, support, k_way=2, n_shot=15):
    """Average support embeddings per class.

    ``support`` must hold exactly ``k_way * n_shot`` rows laid out as ``k_way``
    consecutive blocks of ``n_shot`` embeddings. The result has one row per class.
    """
    class_prototypes = support.reshape(k_way, n_shot, -1).mean(dim=1)
    return class_prototypes

  def pairwise_distances(self, x, y):
    """Squared Euclidean distance between every row of ``x`` and every row of ``y``.

    Args:
      x: 2-D tensor of shape ``(n, d)``.
      y: 2-D tensor of shape ``(m, d)``.

    Returns:
      Tensor of shape ``(n, m)``.

    Raises:
      ValueError: if the two inputs have different feature sizes.
    """
    if x.size(1) != y.size(1):
      raise ValueError(
          f"feature sizes differ: x has shape {tuple(x.shape)}, y has shape {tuple(y.shape)}"
      )

    # (n, 1, d) - (1, m, d) broadcasts to (n, m, d) without materialising copies.
    return torch.pow(x.unsqueeze(1) - y.unsqueeze(0), 2).sum(2)
  
  

In [232]:
# Sort by particle so rows of the same class are contiguous, which is the
# layout ProtoNet.get_prototypes assumes when it reshapes to (k_way, n_shot, -1).
support_df = support_df.sort_values(by='particle')

# Load every support image in that order and cast it to float.
imgs = [read_im(img_path).float() for img_path in support_df['path']]

support = torch.stack(imgs)
support.shape

torch.Size([30, 3, 100, 100])

In [233]:
from pytorch_lightning.loggers import *

# TensorBoard event files are written under tb_logs/classification_model.
logger = TensorBoardLogger("tb_logs", name="classification_model")

# The model keeps the support batch and embeds it on every forward pass.
model = ProtoNet(support)

# CPU-only run (gpus=0) for 100 epochs, validating every 10th epoch.
trainer = pl.Trainer(
    check_val_every_n_epoch=10,
    gpus=0,
    max_epochs=100,
    callbacks=[LitProgressBar()],
    logger=logger,
)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores


cpu




In [None]:
# Run training: batches come from train_loader, validation uses val_loader
# at the interval configured on the Trainer (check_val_every_n_epoch=10).
trainer.fit(model, train_loader, val_loader)


  | Name    | Type     | Params
-------------------------------------
0 | encoder | Embedder | 113 K 
1 | loss    | NLLLoss  | 0     
-------------------------------------
113 K     Trainable params
0         Non-trainable params
113 K     Total params
0.452     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…