In [1]:
import torch
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch.nn as nn
import torch.nn.functional as F

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
print("RUN CELL")

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/das-cnn-cars/data.pt
RUN CELL


In [2]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, imgs, counts, labels, ids, transforms = lambda x: x):
      self.imgs = imgs
      self.labels = labels
      self.counts = counts
      self.labels = labels
      self.transforms = transforms
      self.ids = ids

    def __getitem__(self, i, return_id = False):
      if return_id:
        return (self.transforms(self.imgs[i].unsqueeze(0)), self.labels[i], self.counts[i], self.ids[i])
      return (self.transforms(self.imgs[i].unsqueeze(0)), self.labels[i], self.counts[i])

    def __len__(self):
      return len(self.imgs)

In [3]:
class SimpleVehicleNet(nn.Module):
    def __init__(self, num_bins=8, num_count_classes = 7):  # num_bins = number of vehicle count buckets
        super(SimpleVehicleNet, self).__init__()
        self.num_count_classes = num_count_classes
        # Input: [B, 1, 585, 130]
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),  # [B, 16, 585, 130]
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),  # [B, 16, 292, 65]

            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),  # [B, 32, 292, 65]
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),  # [B, 32, 146, 32]

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # [B, 64, 146, 32]
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),  # [B, 64, 1, 1]
        )

        self.flatten = nn.Flatten()  # → [B, 64]

        # Regression head (predict scalar count)
        self.count_head = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, num_count_classes)
        )

        # Distribution head (predict vehicle count histogram)
        self.dist_head = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, num_bins)  # Output: raw logits → apply log_softmax
        )
        self.mse = nn.MSELoss()
        self.mae = nn.L1Loss()
        self.cross_entropy = nn.CrossEntropyLoss()
        self.kl = nn.KLDivLoss(reduction = "batchmean")
        self.lambda_count = 10
        self.lambda_kl = 1
        self.lambda_mse = 1
        self.freq = torch.Tensor([3.8822e-02, 5.4998e-03, 8.6703e-01, 4.8528e-03, 0.0000e+00, 0.0000e+00,
        3.2352e-04, 8.3468e-02])

    def masked_mse(self, y_pred, y_true):
        # Create a mask where target is non-zero
        mask = (y_true != 0).float()
        
        # Compute squared error only where mask == 1
        loss = (mask * (y_pred - y_true) ** 2)
        
        # Avoid division by zero: normalize by number of non-zero elements
        return loss.sum() / (mask.sum() + 1e-8)

    def weighted_mse(self, y_pred, y_true):
        weights = 1.0 / (self.freq + 1e-6)
        weights = weights / weights.sum()
        weights = weights.to(y_pred.device)
        loss = weights * (y_pred - y_true) ** 2
        return loss.mean()
    
    def forward(self, x):
        x = self.backbone(x)
        x = self.flatten(x)
        count = self.count_head(x)                      # shape: [B, 1]
        dist_logits = self.dist_head(x)                # shape: [B, num_bins]
        dist_output = F.softmax(dist_logits, dim=1)
        return count, dist_output

    def loss(self, count_out, hist_out, count_label, hist_label):
        label = F.one_hot(count_label, num_classes = self.num_count_classes).float()
        #print(count_out.shape, label.shape)
        #print(count_out.dtype, label.dtype)
        count_loss = self.cross_entropy(count_out, label)
        row_sums = hist_label.sum(dim=1, keepdim=True)
        normalized = hist_label / row_sums
        #hist_loss = self.weighted_mse(hist_out, hist_label)
        hist_loss = self.cross_entropy(hist_out, normalized)
        return self.lambda_count * count_loss + hist_loss

In [4]:
class CountCNN(nn.Module):
    def __init__(self, classes = 7):  # num_bins = number of vehicle count buckets
        super(CountCNN, self).__init__()

        # Input: [B, 1, 585, 130]
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),  # [B, 16, 585, 130]
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),  # [B, 16, 292, 65]

            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),  # [B, 32, 292, 65]
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),  # [B, 32, 146, 32]

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # [B, 64, 146, 32]
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),  # [B, 64, 1, 1]
        )

        self.flatten = nn.Flatten()  # → [B, 64]

        # Regression head (predict scalar count)
        self.count_head = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, classes)
        )
        
        self.mse = nn.MSELoss()
        self.mae = nn.L1Loss()
        self.cross_entropy = nn.CrossEntropyLoss()
        self.freq = torch.Tensor([3.8822e-02, 5.4998e-03, 8.6703e-01, 4.8528e-03, 0.0000e+00, 0.0000e+00,
        3.2352e-04, 8.3468e-02])
        self.classes = classes
        
    def forward(self, x):
        x = self.backbone(x)
        x = self.flatten(x)
        count = self.count_head(x)                      # shape: [B, 1]
        return count

    def loss(self, count_out, count_label):
        label = F.one_hot(count_label, num_classes = self.classes).float()
        count_loss = self.cross_entropy(count_out, count_label)
        return count_loss

In [19]:
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F

def append_dropout(model, rate=0.2):
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            append_dropout(module)
        if isinstance(module, nn.ReLU):
            new = nn.Sequential(module, nn.Dropout2d(p=rate))
            setattr(model, name, new)
            
class VehicleCounterNet(nn.Module):
    def __init__(self, num_classes=8, count_classes = 7):
        super(VehicleCounterNet, self).__init__()

        # Load pretrained ResNet18 and modify
        self.backbone = models.resnet18()
        #append_dropout(self.backbone, rate = 0.1)
        # Modify input conv layer to take 1 channel (grayscale)
        self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Replace the final FC layer with identity so we can define our own heads
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # Output head 1: Regress or classify number of vehicles
        self.vehicle_count = nn.Sequential(nn.Linear(num_features, count_classes))

        # Output head 2: Predict histogram (e.g., soft count distribution across possible values)
        self.histogram_head = nn.Sequential(nn.Linear(num_features, num_classes))
        self.cross_entropy = nn.CrossEntropyLoss()
        self.freq = torch.Tensor([3.8822e-02, 5.4998e-03, 8.6703e-01, 4.8528e-03, 0.0000e+00, 0.0000e+00,
        3.2352e-04, 8.3468e-02])
        self.weighted_nll = nn.NLLLoss(weight = self.freq)
        #lambda P, Q: -torch.sum(self.freq*P*torch.log(Q + 1e-9))
        self.mse = nn.MSELoss()
        self.kl = nn.KLDivLoss(reduction = "batchmean")
        self.lambda_count = 10
        self.lambda_kl = 1
        self.lambda_mse = 1
        self.count_classes = count_classes

    def forward(self, x):
        features = self.backbone(x)
        count_output = self.vehicle_count(features)
        histogram_output = F.softmax(self.histogram_head(features), dim = 1)
        return count_output, histogram_output

    def masked_mse(self, y_pred, y_true):
        # Create a mask where target is non-zero
        mask = (y_true != 0).float()
        
        # Compute squared error only where mask == 1
        loss = (mask * (y_pred - y_true) ** 2)
        
        # Avoid division by zero: normalize by number of non-zero elements
        return loss.sum() / (mask.sum() + 1e-8)

    def weighted_soft_nll_loss(self, log_probs, soft_targets):
        """
        log_probs: Tensor of shape [batch_size, num_classes], output of log_softmax
        soft_targets: Tensor of shape [batch_size, num_classes], target distributions
        class_weights: Tensor of shape [num_classes], weight per class
        """
        # Expand class_weights to match soft_targets shape
        weights = self.freq.unsqueeze(0).to(log_probs.device)  # [1, num_classes]
        
        # Apply weights to soft_targets
        weighted_targets = soft_targets * weights  # [batch_size, num_classes]
        
        # Compute element-wise product: -weighted_targets * log_probs
        loss = -torch.sum(weighted_targets * log_probs, dim=1)  # [batch_size]
        
        return loss.mean()  # or use .sum() if preferred



    def loss(self, count_out, hist_out, count_label, hist_label):
        count_label = F.one_hot(count_label, self.count_classes).float()
        count_loss = self.cross_entropy(count_out, count_label)
        target = hist_label + 1e-8
        target_probs = (target) / target.sum(dim=1, keepdim=True)
        expected_hist_loss = self.weighted_soft_nll_loss(torch.log(hist_out), target_probs)
        total_loss = (
          self.lambda_count * count_loss +
          self.lambda_mse * expected_hist_loss
        )
        
        return total_loss

In [6]:
import os
from datetime import datetime

class Logger:
  def __init__(self, log_dir, filename='train_log.txt'):
      os.makedirs(log_dir, exist_ok=True)
      timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
      self.log_dir = log_dir
      self.log_path = os.path.join(log_dir, f'{timestamp}_{filename}')

      with open(self.log_path, 'w') as f:
          f.write(f"Logging started: {timestamp}\n\n")

  def log(self, message):
      timestamp = datetime.now().strftime('%H:%M:%S')
      full_message = f"[{timestamp}] {message}"
      print(full_message)  # also print to stdout
      with open(self.log_path, 'a') as f:
          f.write(full_message + '\n')

  def log_metrics(self, epoch, train_loss=None, val_loss=None, **kwargs):
      msg = f"Epoch {epoch} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}"
      for k, v in kwargs.items():
          msg += f" | {k}: {v:.4f}" if isinstance(v, float) else f" | {k}: {v}"
      self.log(msg)

  def save_checkpoint(self, path, model, optimizer, epoch, best_val_loss):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'best_val_loss': best_val_loss
    }
    torch.save(checkpoint, os.path.join(self.log_dir, path))


In [7]:

data = torch.load("/kaggle/input/das-cnn-cars/data.pt")

  data = torch.load("/kaggle/input/das-cnn-cars/data.pt")


In [8]:
print(data)

{'imgs': tensor([[[-76.5721, -77.1576, -81.6660,  ..., -83.9075, -81.3769, -84.3168],
         [-76.3438, -73.9983, -79.6962,  ..., -80.6483, -82.2012, -87.3538],
         [-76.6467, -74.0830, -81.2737,  ..., -81.5403, -83.2910, -88.4794],
         ...,
         [-79.3313, -78.3320, -78.4678,  ..., -81.0519, -80.5077, -88.1365],
         [-85.3054, -79.8374, -75.8968,  ..., -86.9083, -82.3489, -84.4151],
         [-85.2439, -82.8430, -79.4188,  ..., -80.8857, -87.3668, -87.8926]],

        [[-80.6762, -81.7623, -72.8934,  ..., -85.9792, -84.2322, -86.8569],
         [-87.8347, -79.4662, -78.5318,  ..., -81.9379, -86.0406, -86.1646],
         [-86.9451, -79.9672, -75.6647,  ..., -81.5951, -82.0449, -87.3756],
         ...,
         [-83.1090, -84.2905, -83.3142,  ..., -66.5070, -70.2364, -76.1732],
         [-81.8832, -83.0903, -81.6652,  ..., -79.0737, -79.5598, -82.4076],
         [-81.5252, -81.8138, -81.5861,  ..., -80.0465, -84.4012, -88.0107]],

        [[-81.3736, -82.3884, -82.3

In [9]:
import matplotlib.pyplot as plt
counts = data["counts"]
print(torch.unique(counts, return_counts = True))

(tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 17, 18, 19]), tensor([309, 111,  85,  77, 253, 122, 109,  85,  73, 100,  29,  13,  11,   3,
          8,   1,   1,   1]))


In [10]:
filtered_data = {}
filtered_data["imgs"] = []
filtered_data["counts"] = []
filtered_data["labels"] = []
filtered_data["ids"] = []
#filter out when counts <= 6
for i in range(len(counts)):
    if counts[i] <= 6:
        for k in ["imgs", "counts", "labels", "ids"]:
            filtered_data[k].append(data[k][i].tolist())

In [11]:
for k in ["imgs", "counts", "labels", "ids"]:
    filtered_data[k] = torch.tensor(filtered_data[k])
    print(filtered_data[k].shape)

torch.Size([957, 585, 153])
torch.Size([957])
torch.Size([957, 8])
torch.Size([957])


In [12]:
from torch.utils.data import random_split, DataLoader
from torchvision.transforms import v2
images = filtered_data["imgs"]
counts = filtered_data["counts"]
labels = filtered_data["labels"]
ids = filtered_data["ids"]
train_transform = v2.Compose([
    v2.RandomHorizontalFlip(),         # flip image horizontally
    v2.RandomRotation(15),             # rotate by ±15 degrees
    v2.ToTensor(),                     # convert to tensor
    v2.Normalize((0.5, ), (0.5, ))            # normalize
])
dataset = Dataset(images, counts, labels, ids, transforms = v2.Normalize((0.5, ), (0.5, )))
train_ds, test_ds = random_split(dataset, [0.9, 0.1])



In [13]:
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm

In [21]:
model = VehicleCounterNet() #CountCNN()#VehicleCounterNet() #SimpleVehicleNet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
logger = Logger("/kaggle/working/182_proj_logs_cross_entropy/")
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
val_loader = DataLoader(test_ds)
# Loss and optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-2)
train_losses = []
val_losses = []
# Training loop
num_epochs = 100
best_val_loss = float('inf')
logger.save_checkpoint("best_model.pth", model, optimizer, 0, best_val_loss)
for epoch in tqdm(range(num_epochs)):
    # --------- Train ---------
    model.train()
    train_loss = 0.0
    train_count_correct = 0
    train_label_correct = 0
    train_total = 0

    for images, labels, counts in train_loader:
        images = images.to(device).float()
        labels = labels.to(device).float()
        counts = counts.to(device).long()
        optimizer.zero_grad()
        count_output, label_output = model(images)
        loss = model.loss(count_output, label_output, counts, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        preds = torch.argmax(count_output, dim = -1)
        label_preds = preds.unsqueeze(-1) * label_output
        train_count_correct += (preds == counts).sum().item()
        train_label_correct += (torch.round(label_preds).flatten() == labels.flatten()).sum().item()
        train_total += labels.size(0)

    avg_train_loss = train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    train_acc_count = train_count_correct / train_total
    train_acc_label = train_label_correct / (8*train_total)

    # --------- Validate ---------
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_label_correct = 0
    val_total = 0

    with torch.no_grad():
        for images, labels, counts in val_loader:
            images = images.to(device).float()
            labels = labels.to(device).float()
            counts = counts.to(device).long()#.unsqueeze(1)
            count_output, label_output = model(images)
            loss = model.loss(count_output, label_output, counts, labels)
            val_loss += loss.item()
            preds = torch.argmax(count_output, dim = -1)
            label_preds = preds.unsqueeze(-1) * label_output
            val_correct += (preds == counts.flatten()).sum().item()
            val_label_correct += (torch.round(label_output).flatten() == labels.flatten()).sum().item()
            val_total += labels.size(0)
            
    
    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    val_acc = val_correct / val_total
    val_acc_label = val_label_correct / (8*val_total)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        logger.save_checkpoint("best_model.pth", model, optimizer, epoch, best_val_loss)

    # --------- Print results ---------
    
    logger.log(f"Epoch {epoch+1}/{num_epochs} "
          f"Train Loss: {avg_train_loss:.2f} | Count Acc: {train_acc_count:.2f} | Label Acc: {train_acc_label:.2f}"
          f"|| Val Loss: {avg_val_loss:.2f} | Count Acc: {val_acc:.2f} | Label Acc: {val_acc_label:.2f}")
    '''
    logger.log(f"Epoch {epoch+1}/{num_epochs} "
          f"Train Loss: {avg_train_loss:.2f} | Count Acc: {train_acc_count:.2f}"
          f"|| Val Loss: {avg_val_loss:.2f} | Count Acc: {val_acc:.2f}")   
    '''

  1%|          | 1/100 [00:05<09:48,  5.95s/it]

[22:44:59] Epoch 1/100 Train Loss: 18.64 | Count Acc: 0.38 | Label Acc: 0.88|| Val Loss: 19.60 | Count Acc: 0.34 | Label Acc: 0.88


  2%|▏         | 2/100 [00:11<09:42,  5.95s/it]

[22:45:05] Epoch 2/100 Train Loss: 15.37 | Count Acc: 0.44 | Label Acc: 0.90|| Val Loss: 16.64 | Count Acc: 0.37 | Label Acc: 0.88


  3%|▎         | 3/100 [00:17<09:22,  5.80s/it]

[22:45:10] Epoch 3/100 Train Loss: 14.86 | Count Acc: 0.46 | Label Acc: 0.90|| Val Loss: 16.94 | Count Acc: 0.38 | Label Acc: 0.88


  4%|▍         | 4/100 [00:23<09:10,  5.73s/it]

[22:45:16] Epoch 4/100 Train Loss: 14.63 | Count Acc: 0.48 | Label Acc: 0.90|| Val Loss: 25.82 | Count Acc: 0.31 | Label Acc: 0.88


  5%|▌         | 5/100 [00:28<09:05,  5.74s/it]

[22:45:22] Epoch 5/100 Train Loss: 14.22 | Count Acc: 0.48 | Label Acc: 0.90|| Val Loss: 53.32 | Count Acc: 0.20 | Label Acc: 0.88


  6%|▌         | 6/100 [00:34<08:56,  5.71s/it]

[22:45:27] Epoch 6/100 Train Loss: 14.24 | Count Acc: 0.49 | Label Acc: 0.90|| Val Loss: 16.88 | Count Acc: 0.39 | Label Acc: 0.88


  7%|▋         | 7/100 [00:40<08:48,  5.68s/it]

[22:45:33] Epoch 7/100 Train Loss: 13.93 | Count Acc: 0.49 | Label Acc: 0.90|| Val Loss: 19.40 | Count Acc: 0.33 | Label Acc: 0.88


  8%|▊         | 8/100 [00:46<08:50,  5.77s/it]

[22:45:39] Epoch 8/100 Train Loss: 13.86 | Count Acc: 0.49 | Label Acc: 0.90|| Val Loss: 14.16 | Count Acc: 0.45 | Label Acc: 0.88


  9%|▉         | 9/100 [00:51<08:40,  5.72s/it]

[22:45:44] Epoch 9/100 Train Loss: 13.78 | Count Acc: 0.49 | Label Acc: 0.91|| Val Loss: 33.33 | Count Acc: 0.34 | Label Acc: 0.88


 10%|█         | 10/100 [00:57<08:31,  5.69s/it]

[22:45:50] Epoch 10/100 Train Loss: 13.64 | Count Acc: 0.50 | Label Acc: 0.91|| Val Loss: 21.81 | Count Acc: 0.34 | Label Acc: 0.88


 11%|█         | 11/100 [01:03<08:25,  5.68s/it]

[22:45:56] Epoch 11/100 Train Loss: 13.24 | Count Acc: 0.52 | Label Acc: 0.90|| Val Loss: 25.22 | Count Acc: 0.34 | Label Acc: 0.88


 12%|█▏        | 12/100 [01:08<08:16,  5.64s/it]

[22:46:01] Epoch 12/100 Train Loss: 13.12 | Count Acc: 0.51 | Label Acc: 0.91|| Val Loss: 18.58 | Count Acc: 0.38 | Label Acc: 0.88


 13%|█▎        | 13/100 [01:14<08:10,  5.64s/it]

[22:46:07] Epoch 13/100 Train Loss: 12.99 | Count Acc: 0.50 | Label Acc: 0.91|| Val Loss: 15.03 | Count Acc: 0.42 | Label Acc: 0.88


 14%|█▍        | 14/100 [01:19<08:03,  5.62s/it]

[22:46:12] Epoch 14/100 Train Loss: 12.89 | Count Acc: 0.53 | Label Acc: 0.91|| Val Loss: 14.32 | Count Acc: 0.51 | Label Acc: 0.88


 15%|█▌        | 15/100 [01:25<07:57,  5.61s/it]

[22:46:18] Epoch 15/100 Train Loss: 12.88 | Count Acc: 0.53 | Label Acc: 0.91|| Val Loss: 20.84 | Count Acc: 0.37 | Label Acc: 0.88


 16%|█▌        | 16/100 [01:31<07:59,  5.70s/it]

[22:46:24] Epoch 16/100 Train Loss: 13.10 | Count Acc: 0.51 | Label Acc: 0.91|| Val Loss: 13.31 | Count Acc: 0.52 | Label Acc: 0.88


 17%|█▋        | 17/100 [01:36<07:52,  5.70s/it]

[22:46:30] Epoch 17/100 Train Loss: 12.66 | Count Acc: 0.52 | Label Acc: 0.91|| Val Loss: 40.39 | Count Acc: 0.17 | Label Acc: 0.88


 18%|█▊        | 18/100 [01:42<07:45,  5.68s/it]

[22:46:35] Epoch 18/100 Train Loss: 12.63 | Count Acc: 0.52 | Label Acc: 0.91|| Val Loss: 13.33 | Count Acc: 0.52 | Label Acc: 0.88


 19%|█▉        | 19/100 [01:48<07:38,  5.66s/it]

[22:46:41] Epoch 19/100 Train Loss: 12.67 | Count Acc: 0.53 | Label Acc: 0.91|| Val Loss: 13.80 | Count Acc: 0.45 | Label Acc: 0.88


 20%|██        | 20/100 [01:53<07:33,  5.67s/it]

[22:46:47] Epoch 20/100 Train Loss: 12.42 | Count Acc: 0.53 | Label Acc: 0.91|| Val Loss: 16.26 | Count Acc: 0.38 | Label Acc: 0.88


 21%|██        | 21/100 [01:59<07:27,  5.66s/it]

[22:46:52] Epoch 21/100 Train Loss: 12.26 | Count Acc: 0.54 | Label Acc: 0.91|| Val Loss: 15.11 | Count Acc: 0.42 | Label Acc: 0.88


 22%|██▏       | 22/100 [02:05<07:22,  5.67s/it]

[22:46:58] Epoch 22/100 Train Loss: 12.26 | Count Acc: 0.54 | Label Acc: 0.91|| Val Loss: 15.30 | Count Acc: 0.43 | Label Acc: 0.88


 23%|██▎       | 23/100 [02:11<07:24,  5.77s/it]

[22:47:04] Epoch 23/100 Train Loss: 12.22 | Count Acc: 0.53 | Label Acc: 0.91|| Val Loss: 12.98 | Count Acc: 0.47 | Label Acc: 0.88


 24%|██▍       | 24/100 [02:16<07:15,  5.73s/it]

[22:47:10] Epoch 24/100 Train Loss: 12.06 | Count Acc: 0.54 | Label Acc: 0.91|| Val Loss: 15.01 | Count Acc: 0.45 | Label Acc: 0.88


 25%|██▌       | 25/100 [02:22<07:07,  5.70s/it]

[22:47:15] Epoch 25/100 Train Loss: 12.14 | Count Acc: 0.53 | Label Acc: 0.91|| Val Loss: 14.60 | Count Acc: 0.45 | Label Acc: 0.88


 26%|██▌       | 26/100 [02:28<07:00,  5.68s/it]

[22:47:21] Epoch 26/100 Train Loss: 12.08 | Count Acc: 0.53 | Label Acc: 0.91|| Val Loss: 13.48 | Count Acc: 0.49 | Label Acc: 0.88


 27%|██▋       | 27/100 [02:33<06:53,  5.67s/it]

[22:47:26] Epoch 27/100 Train Loss: 12.07 | Count Acc: 0.54 | Label Acc: 0.91|| Val Loss: 16.11 | Count Acc: 0.42 | Label Acc: 0.88


 28%|██▊       | 28/100 [02:39<06:48,  5.67s/it]

[22:47:32] Epoch 28/100 Train Loss: 11.58 | Count Acc: 0.56 | Label Acc: 0.91|| Val Loss: 14.31 | Count Acc: 0.45 | Label Acc: 0.88


 29%|██▉       | 29/100 [02:45<06:41,  5.66s/it]

[22:47:38] Epoch 29/100 Train Loss: 11.60 | Count Acc: 0.54 | Label Acc: 0.91|| Val Loss: 16.53 | Count Acc: 0.48 | Label Acc: 0.88


 30%|███       | 30/100 [02:50<06:36,  5.66s/it]

[22:47:43] Epoch 30/100 Train Loss: 11.62 | Count Acc: 0.55 | Label Acc: 0.91|| Val Loss: 15.20 | Count Acc: 0.41 | Label Acc: 0.88


 31%|███       | 31/100 [02:56<06:30,  5.65s/it]

[22:47:49] Epoch 31/100 Train Loss: 11.71 | Count Acc: 0.54 | Label Acc: 0.91|| Val Loss: 13.64 | Count Acc: 0.45 | Label Acc: 0.88


 32%|███▏      | 32/100 [03:02<06:23,  5.64s/it]

[22:47:55] Epoch 32/100 Train Loss: 11.40 | Count Acc: 0.56 | Label Acc: 0.91|| Val Loss: 31.67 | Count Acc: 0.35 | Label Acc: 0.88


 33%|███▎      | 33/100 [03:07<06:17,  5.64s/it]

[22:48:00] Epoch 33/100 Train Loss: 11.27 | Count Acc: 0.56 | Label Acc: 0.91|| Val Loss: 21.41 | Count Acc: 0.25 | Label Acc: 0.88


 34%|███▍      | 34/100 [03:13<06:12,  5.64s/it]

[22:48:06] Epoch 34/100 Train Loss: 11.25 | Count Acc: 0.56 | Label Acc: 0.91|| Val Loss: 16.24 | Count Acc: 0.44 | Label Acc: 0.88


 35%|███▌      | 35/100 [03:18<06:05,  5.63s/it]

[22:48:12] Epoch 35/100 Train Loss: 11.20 | Count Acc: 0.55 | Label Acc: 0.91|| Val Loss: 72.24 | Count Acc: 0.17 | Label Acc: 0.88


 36%|███▌      | 36/100 [03:24<06:00,  5.64s/it]

[22:48:17] Epoch 36/100 Train Loss: 10.87 | Count Acc: 0.59 | Label Acc: 0.91|| Val Loss: 24.56 | Count Acc: 0.23 | Label Acc: 0.88


 37%|███▋      | 37/100 [03:30<05:54,  5.63s/it]

[22:48:23] Epoch 37/100 Train Loss: 10.79 | Count Acc: 0.57 | Label Acc: 0.91|| Val Loss: 16.09 | Count Acc: 0.38 | Label Acc: 0.88


 38%|███▊      | 38/100 [03:35<05:49,  5.64s/it]

[22:48:29] Epoch 38/100 Train Loss: 10.75 | Count Acc: 0.57 | Label Acc: 0.91|| Val Loss: 18.18 | Count Acc: 0.43 | Label Acc: 0.88


 39%|███▉      | 39/100 [03:41<05:44,  5.65s/it]

[22:48:34] Epoch 39/100 Train Loss: 10.38 | Count Acc: 0.59 | Label Acc: 0.91|| Val Loss: 15.37 | Count Acc: 0.46 | Label Acc: 0.88


 40%|████      | 40/100 [03:47<05:38,  5.65s/it]

[22:48:40] Epoch 40/100 Train Loss: 10.37 | Count Acc: 0.59 | Label Acc: 0.91|| Val Loss: 16.71 | Count Acc: 0.45 | Label Acc: 0.88


 41%|████      | 41/100 [03:52<05:33,  5.65s/it]

[22:48:46] Epoch 41/100 Train Loss: 10.26 | Count Acc: 0.59 | Label Acc: 0.91|| Val Loss: 24.75 | Count Acc: 0.36 | Label Acc: 0.88


 42%|████▏     | 42/100 [03:58<05:28,  5.66s/it]

[22:48:51] Epoch 42/100 Train Loss: 9.98 | Count Acc: 0.61 | Label Acc: 0.91|| Val Loss: 19.56 | Count Acc: 0.42 | Label Acc: 0.88


 43%|████▎     | 43/100 [04:04<05:22,  5.66s/it]

[22:48:57] Epoch 43/100 Train Loss: 10.02 | Count Acc: 0.61 | Label Acc: 0.91|| Val Loss: 17.70 | Count Acc: 0.36 | Label Acc: 0.88


 44%|████▍     | 44/100 [04:09<05:17,  5.67s/it]

[22:49:03] Epoch 44/100 Train Loss: 9.58 | Count Acc: 0.64 | Label Acc: 0.91|| Val Loss: 28.04 | Count Acc: 0.26 | Label Acc: 0.88


 45%|████▌     | 45/100 [04:15<05:12,  5.69s/it]

[22:49:08] Epoch 45/100 Train Loss: 9.42 | Count Acc: 0.64 | Label Acc: 0.91|| Val Loss: 18.97 | Count Acc: 0.42 | Label Acc: 0.88


 46%|████▌     | 46/100 [04:21<05:06,  5.67s/it]

[22:49:14] Epoch 46/100 Train Loss: 9.19 | Count Acc: 0.65 | Label Acc: 0.91|| Val Loss: 19.52 | Count Acc: 0.40 | Label Acc: 0.88


 47%|████▋     | 47/100 [04:26<05:00,  5.67s/it]

[22:49:20] Epoch 47/100 Train Loss: 8.86 | Count Acc: 0.66 | Label Acc: 0.91|| Val Loss: 15.84 | Count Acc: 0.47 | Label Acc: 0.88


 48%|████▊     | 48/100 [04:32<04:55,  5.68s/it]

[22:49:25] Epoch 48/100 Train Loss: 8.82 | Count Acc: 0.66 | Label Acc: 0.91|| Val Loss: 15.49 | Count Acc: 0.42 | Label Acc: 0.88


 49%|████▉     | 49/100 [04:38<04:49,  5.68s/it]

[22:49:31] Epoch 49/100 Train Loss: 8.22 | Count Acc: 0.70 | Label Acc: 0.91|| Val Loss: 16.39 | Count Acc: 0.41 | Label Acc: 0.88


 50%|█████     | 50/100 [04:44<04:44,  5.70s/it]

[22:49:37] Epoch 50/100 Train Loss: 7.93 | Count Acc: 0.71 | Label Acc: 0.91|| Val Loss: 15.25 | Count Acc: 0.51 | Label Acc: 0.88


 51%|█████     | 51/100 [04:49<04:39,  5.70s/it]

[22:49:42] Epoch 51/100 Train Loss: 7.13 | Count Acc: 0.75 | Label Acc: 0.91|| Val Loss: 28.96 | Count Acc: 0.31 | Label Acc: 0.88


 52%|█████▏    | 52/100 [04:55<04:33,  5.70s/it]

[22:49:48] Epoch 52/100 Train Loss: 7.10 | Count Acc: 0.73 | Label Acc: 0.91|| Val Loss: 19.20 | Count Acc: 0.41 | Label Acc: 0.88


 53%|█████▎    | 53/100 [05:01<04:27,  5.69s/it]

[22:49:54] Epoch 53/100 Train Loss: 6.81 | Count Acc: 0.76 | Label Acc: 0.91|| Val Loss: 16.74 | Count Acc: 0.45 | Label Acc: 0.88


 54%|█████▍    | 54/100 [05:06<04:20,  5.67s/it]

[22:49:59] Epoch 54/100 Train Loss: 6.56 | Count Acc: 0.76 | Label Acc: 0.92|| Val Loss: 19.38 | Count Acc: 0.47 | Label Acc: 0.88


 55%|█████▌    | 55/100 [05:12<04:14,  5.66s/it]

[22:50:05] Epoch 55/100 Train Loss: 5.55 | Count Acc: 0.79 | Label Acc: 0.92|| Val Loss: 153.27 | Count Acc: 0.17 | Label Acc: 0.88


 56%|█████▌    | 56/100 [05:18<04:08,  5.66s/it]

[22:50:11] Epoch 56/100 Train Loss: 5.78 | Count Acc: 0.79 | Label Acc: 0.91|| Val Loss: 20.24 | Count Acc: 0.51 | Label Acc: 0.88


 57%|█████▋    | 57/100 [05:23<04:03,  5.66s/it]

[22:50:16] Epoch 57/100 Train Loss: 5.21 | Count Acc: 0.81 | Label Acc: 0.92|| Val Loss: 31.54 | Count Acc: 0.43 | Label Acc: 0.88


 58%|█████▊    | 58/100 [05:29<03:57,  5.66s/it]

[22:50:22] Epoch 58/100 Train Loss: 5.18 | Count Acc: 0.82 | Label Acc: 0.91|| Val Loss: 24.07 | Count Acc: 0.38 | Label Acc: 0.88


 59%|█████▉    | 59/100 [05:35<03:52,  5.68s/it]

[22:50:28] Epoch 59/100 Train Loss: 4.75 | Count Acc: 0.84 | Label Acc: 0.92|| Val Loss: 19.38 | Count Acc: 0.53 | Label Acc: 0.88


 60%|██████    | 60/100 [05:40<03:47,  5.68s/it]

[22:50:33] Epoch 60/100 Train Loss: 3.83 | Count Acc: 0.86 | Label Acc: 0.92|| Val Loss: 23.26 | Count Acc: 0.49 | Label Acc: 0.88


 61%|██████    | 61/100 [05:46<03:41,  5.67s/it]

[22:50:39] Epoch 61/100 Train Loss: 3.43 | Count Acc: 0.88 | Label Acc: 0.92|| Val Loss: 123.06 | Count Acc: 0.19 | Label Acc: 0.88


 62%|██████▏   | 62/100 [05:52<03:35,  5.66s/it]

[22:50:45] Epoch 62/100 Train Loss: 3.28 | Count Acc: 0.89 | Label Acc: 0.92|| Val Loss: 26.80 | Count Acc: 0.39 | Label Acc: 0.88


 63%|██████▎   | 63/100 [05:57<03:29,  5.65s/it]

[22:50:50] Epoch 63/100 Train Loss: 3.63 | Count Acc: 0.87 | Label Acc: 0.92|| Val Loss: 26.89 | Count Acc: 0.46 | Label Acc: 0.88


 64%|██████▍   | 64/100 [06:03<03:23,  5.65s/it]

[22:50:56] Epoch 64/100 Train Loss: 2.74 | Count Acc: 0.91 | Label Acc: 0.92|| Val Loss: 24.17 | Count Acc: 0.52 | Label Acc: 0.88


 65%|██████▌   | 65/100 [06:08<03:17,  5.65s/it]

[22:51:02] Epoch 65/100 Train Loss: 2.99 | Count Acc: 0.89 | Label Acc: 0.92|| Val Loss: 25.59 | Count Acc: 0.54 | Label Acc: 0.88


 66%|██████▌   | 66/100 [06:14<03:11,  5.64s/it]

[22:51:07] Epoch 66/100 Train Loss: 2.09 | Count Acc: 0.92 | Label Acc: 0.92|| Val Loss: 55.71 | Count Acc: 0.38 | Label Acc: 0.88


 67%|██████▋   | 67/100 [06:20<03:06,  5.65s/it]

[22:51:13] Epoch 67/100 Train Loss: 2.29 | Count Acc: 0.93 | Label Acc: 0.92|| Val Loss: 24.35 | Count Acc: 0.40 | Label Acc: 0.88


 68%|██████▊   | 68/100 [06:25<03:01,  5.66s/it]

[22:51:19] Epoch 68/100 Train Loss: 2.00 | Count Acc: 0.93 | Label Acc: 0.92|| Val Loss: 34.10 | Count Acc: 0.42 | Label Acc: 0.88


 69%|██████▉   | 69/100 [06:31<02:55,  5.66s/it]

[22:51:24] Epoch 69/100 Train Loss: 1.88 | Count Acc: 0.93 | Label Acc: 0.92|| Val Loss: 29.29 | Count Acc: 0.41 | Label Acc: 0.88


 70%|███████   | 70/100 [06:37<02:49,  5.67s/it]

[22:51:30] Epoch 70/100 Train Loss: 2.13 | Count Acc: 0.93 | Label Acc: 0.92|| Val Loss: 33.18 | Count Acc: 0.44 | Label Acc: 0.88


 71%|███████   | 71/100 [06:42<02:44,  5.68s/it]

[22:51:36] Epoch 71/100 Train Loss: 2.33 | Count Acc: 0.93 | Label Acc: 0.92|| Val Loss: 25.85 | Count Acc: 0.38 | Label Acc: 0.88


 72%|███████▏  | 72/100 [06:48<02:38,  5.67s/it]

[22:51:41] Epoch 72/100 Train Loss: 1.54 | Count Acc: 0.95 | Label Acc: 0.92|| Val Loss: 58.41 | Count Acc: 0.38 | Label Acc: 0.88


 73%|███████▎  | 73/100 [06:54<02:33,  5.67s/it]

[22:51:47] Epoch 73/100 Train Loss: 1.48 | Count Acc: 0.95 | Label Acc: 0.92|| Val Loss: 39.08 | Count Acc: 0.41 | Label Acc: 0.88


 74%|███████▍  | 74/100 [06:59<02:27,  5.66s/it]

[22:51:53] Epoch 74/100 Train Loss: 1.07 | Count Acc: 0.97 | Label Acc: 0.93|| Val Loss: 30.20 | Count Acc: 0.43 | Label Acc: 0.88


 75%|███████▌  | 75/100 [07:05<02:21,  5.65s/it]

[22:51:58] Epoch 75/100 Train Loss: 0.86 | Count Acc: 0.98 | Label Acc: 0.93|| Val Loss: 34.57 | Count Acc: 0.40 | Label Acc: 0.88


 76%|███████▌  | 76/100 [07:11<02:15,  5.65s/it]

[22:52:04] Epoch 76/100 Train Loss: 1.83 | Count Acc: 0.94 | Label Acc: 0.92|| Val Loss: 30.93 | Count Acc: 0.48 | Label Acc: 0.88


 77%|███████▋  | 77/100 [07:16<02:09,  5.65s/it]

[22:52:10] Epoch 77/100 Train Loss: 1.09 | Count Acc: 0.97 | Label Acc: 0.93|| Val Loss: 27.13 | Count Acc: 0.48 | Label Acc: 0.88


 78%|███████▊  | 78/100 [07:22<02:04,  5.66s/it]

[22:52:15] Epoch 78/100 Train Loss: 0.88 | Count Acc: 0.98 | Label Acc: 0.93|| Val Loss: 43.75 | Count Acc: 0.46 | Label Acc: 0.88


 79%|███████▉  | 79/100 [07:28<01:59,  5.67s/it]

[22:52:21] Epoch 79/100 Train Loss: 1.26 | Count Acc: 0.97 | Label Acc: 0.93|| Val Loss: 31.91 | Count Acc: 0.46 | Label Acc: 0.88


 80%|████████  | 80/100 [07:33<01:53,  5.66s/it]

[22:52:27] Epoch 80/100 Train Loss: 1.75 | Count Acc: 0.94 | Label Acc: 0.92|| Val Loss: 32.82 | Count Acc: 0.44 | Label Acc: 0.88


 81%|████████  | 81/100 [07:39<01:47,  5.65s/it]

[22:52:32] Epoch 81/100 Train Loss: 2.04 | Count Acc: 0.93 | Label Acc: 0.92|| Val Loss: 96.27 | Count Acc: 0.19 | Label Acc: 0.88


 82%|████████▏ | 82/100 [07:45<01:41,  5.64s/it]

[22:52:38] Epoch 82/100 Train Loss: 0.91 | Count Acc: 0.97 | Label Acc: 0.93|| Val Loss: 39.97 | Count Acc: 0.43 | Label Acc: 0.88


 83%|████████▎ | 83/100 [07:50<01:35,  5.64s/it]

[22:52:43] Epoch 83/100 Train Loss: 0.54 | Count Acc: 0.99 | Label Acc: 0.93|| Val Loss: 40.81 | Count Acc: 0.43 | Label Acc: 0.88


 84%|████████▍ | 84/100 [07:56<01:30,  5.67s/it]

[22:52:49] Epoch 84/100 Train Loss: 0.22 | Count Acc: 1.00 | Label Acc: 0.93|| Val Loss: 34.83 | Count Acc: 0.47 | Label Acc: 0.88


 85%|████████▌ | 85/100 [08:02<01:24,  5.66s/it]

[22:52:55] Epoch 85/100 Train Loss: 0.49 | Count Acc: 0.99 | Label Acc: 0.93|| Val Loss: 34.35 | Count Acc: 0.42 | Label Acc: 0.88


 86%|████████▌ | 86/100 [08:07<01:19,  5.66s/it]

[22:53:00] Epoch 86/100 Train Loss: 0.97 | Count Acc: 0.98 | Label Acc: 0.93|| Val Loss: 30.47 | Count Acc: 0.46 | Label Acc: 0.88


 87%|████████▋ | 87/100 [08:13<01:13,  5.66s/it]

[22:53:06] Epoch 87/100 Train Loss: 2.40 | Count Acc: 0.93 | Label Acc: 0.92|| Val Loss: 30.55 | Count Acc: 0.45 | Label Acc: 0.88


 88%|████████▊ | 88/100 [08:19<01:07,  5.66s/it]

[22:53:12] Epoch 88/100 Train Loss: 2.13 | Count Acc: 0.93 | Label Acc: 0.92|| Val Loss: 42.60 | Count Acc: 0.40 | Label Acc: 0.88


 89%|████████▉ | 89/100 [08:24<01:02,  5.67s/it]

[22:53:17] Epoch 89/100 Train Loss: 0.53 | Count Acc: 0.99 | Label Acc: 0.93|| Val Loss: 30.61 | Count Acc: 0.49 | Label Acc: 0.88


 90%|█████████ | 90/100 [08:30<00:56,  5.68s/it]

[22:53:23] Epoch 90/100 Train Loss: 0.21 | Count Acc: 1.00 | Label Acc: 0.93|| Val Loss: 33.61 | Count Acc: 0.49 | Label Acc: 0.88


 91%|█████████ | 91/100 [08:36<00:51,  5.67s/it]

[22:53:29] Epoch 91/100 Train Loss: 0.20 | Count Acc: 1.00 | Label Acc: 0.93|| Val Loss: 37.34 | Count Acc: 0.48 | Label Acc: 0.88


 92%|█████████▏| 92/100 [08:41<00:45,  5.66s/it]

[22:53:34] Epoch 92/100 Train Loss: 0.11 | Count Acc: 1.00 | Label Acc: 0.93|| Val Loss: 35.52 | Count Acc: 0.53 | Label Acc: 0.88


 93%|█████████▎| 93/100 [08:47<00:39,  5.65s/it]

[22:53:40] Epoch 93/100 Train Loss: 0.07 | Count Acc: 1.00 | Label Acc: 0.93|| Val Loss: 36.07 | Count Acc: 0.53 | Label Acc: 0.88


 94%|█████████▍| 94/100 [08:53<00:33,  5.66s/it]

[22:53:46] Epoch 94/100 Train Loss: 0.06 | Count Acc: 1.00 | Label Acc: 0.93|| Val Loss: 37.58 | Count Acc: 0.53 | Label Acc: 0.88


 95%|█████████▌| 95/100 [08:58<00:28,  5.65s/it]

[22:53:51] Epoch 95/100 Train Loss: 0.06 | Count Acc: 1.00 | Label Acc: 0.93|| Val Loss: 35.65 | Count Acc: 0.53 | Label Acc: 0.88


 96%|█████████▌| 96/100 [09:04<00:22,  5.66s/it]

[22:53:57] Epoch 96/100 Train Loss: 0.05 | Count Acc: 1.00 | Label Acc: 0.93|| Val Loss: 37.53 | Count Acc: 0.52 | Label Acc: 0.88


 97%|█████████▋| 97/100 [09:10<00:16,  5.65s/it]

[22:54:03] Epoch 97/100 Train Loss: 0.05 | Count Acc: 1.00 | Label Acc: 0.93|| Val Loss: 39.65 | Count Acc: 0.49 | Label Acc: 0.88


 98%|█████████▊| 98/100 [09:15<00:11,  5.66s/it]

[22:54:08] Epoch 98/100 Train Loss: 0.07 | Count Acc: 1.00 | Label Acc: 0.93|| Val Loss: 47.53 | Count Acc: 0.43 | Label Acc: 0.88


 99%|█████████▉| 99/100 [09:21<00:05,  5.65s/it]

[22:54:14] Epoch 99/100 Train Loss: 7.69 | Count Acc: 0.81 | Label Acc: 0.91|| Val Loss: 22.08 | Count Acc: 0.39 | Label Acc: 0.88


100%|██████████| 100/100 [09:27<00:00,  5.67s/it]

[22:54:20] Epoch 100/100 Train Loss: 1.10 | Count Acc: 0.97 | Label Acc: 0.93|| Val Loss: 35.72 | Count Acc: 0.42 | Label Acc: 0.88





In [15]:
 with torch.no_grad():
    for images, labels, counts in val_loader:
        images = images.to(device).float()
        labels = labels.to(device).float()
        counts = counts.to(device).float().unsqueeze(1)
        count_output, label_output = model(images)
        loss = model.loss(count_output, label_output, counts, labels)
        val_loss += loss.item()
        preds = torch.round(count_output).flatten()
        val_correct += (preds == counts.flatten()).sum().item()
        val_label_correct += (torch.round(label_output).flatten() == labels.flatten()).sum().item()
        val_total += labels.size(0)
        print(label_output)

RuntimeError: one_hot is only applicable to index tensor of type LongTensor.

In [None]:
freq = filtered_data["labels"].sum(dim = 0)
freq = freq / freq.sum()
print(freq)