In [66]:
%load_ext autoreload
%autoreload 2

In [70]:
import sys
import numpy as np
import matplotlib.pyplot as plt

# Pytorch 
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader

# Other custom files
sys.path.append('../')
import modules.Dataset as data_handler
import modules.transforms as transforms
import modules.model as model

In [88]:
from collections import defaultdict

# Dataset

In [71]:
CROP_DIM = 256
data_root = '../data/processed'

## Transforms

In [72]:
data_transforms = torchvision.transforms.Compose([transforms.RandomCropper(CROP_DIM),
                               transforms.LRFlipper(),
                               transforms.Rotator(),
                               transforms.ToTensor(),
                               transforms.Normalizer()
                                ])

## Load data

In [73]:
dataset = data_handler.BeeDataset(data_root, data_transforms)

Loading paths...
Num paths loaded: 10


In [74]:
dataloader = DataLoader(dataset, batch_size=4)

In [83]:
n_classes = 3

In [96]:
frame_avg = defaultdict(list)
for x, y in dataloader:
    for i in range(n_classes):
        avg = len(y[y==i]) / np.prod(y.numpy().shape)
        frame_avg[i].append(avg)

In [97]:
frame_avg

defaultdict(list,
            {0: [0.8559722900390625, 0.8431205749511719, 0.8412933349609375],
             1: [0.13506317138671875, 0.14374160766601562, 0.1365509033203125],
             2: [0.00896453857421875, 0.0131378173828125, 0.02215576171875]})

In [135]:
new_dict = {}
for key, val in frame_avg.items():
    mean_avg = np.mean(val)
    new_dict[key] = mean_avg

In [136]:
new_dict

{0: 0.846795399983724, 1: 0.13845189412434897, 2: 0.014752705891927084}

In [150]:
weights = np.array(list(new_dict.values()))
w = weights[0]
weights = [1-w, w/2, w/2]

sum(weights)

1.0

In [151]:
weights

[0.153204600016276, 0.423397699991862, 0.423397699991862]

In [152]:
np.save('../data/class_weights.npy', weights)

In [141]:
np.sum(weights)

1.0