Merge pull request #25 from decile-team/doc_plots
Doc Index Plots and Utils Docstrings
durgas16 committed May 7, 2021
2 parents 77813ae + 3d8767b commit aa001b6
Showing 4 changed files with 89 additions and 14 deletions.
30 changes: 29 additions & 1 deletion distil/utils/calculate_class_budgets.py
@@ -3,6 +3,34 @@
import torch

def calculate_class_budgets(budget, num_classes, trn_lbls, N_trn):

"""
Calculates a list of class budgets whose sum is that of the specified budget.
Furthermore, each budget calculated for a class is based on the proportion
of labels of that class that appear in trn_lbls. For example, if trn_lbls has
50% "0" labels, then the budget calculated for class "0" will be 50% of the full
budget.
Specifically, this function makes sure to at least give every class a budget of 1.
If this violates the full budget (the sum of the per-class budgets is greater than
the full budget), then random class budgets are set to 0 until the budget constraint
is satisfied.
If the budget constraint is not broken, then it awards the rest of the full budget
in the proportion described above in a best-attempt manner.
Parameters
----------
budget: int
Full budget to split into class budgets
num_classes: int
Number of per-class budgets to calculate
trn_lbls: Torch tensor
Label tensor on which to base per-class budgets
N_trn: int
Number of labels in trn_lbls
"""

# Tabulate class populations
class_pops = list()
for i in range(num_classes):
@@ -72,4 +100,4 @@ def calculate_class_budgets(budget, num_classes, trn_lbls, N_trn):
for i in range(num_classes):
class_budgets[i] = floored_class_budgets[i][1]

return class_budgets
return class_budgets
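The proportional split with a minimum-1 guarantee described in the docstring can be sketched as follows. This is a simplified illustration of the technique, not the exact implementation committed here (the real function's tie-breaking and rounding details may differ):

```python
import torch

def split_budget(budget, num_classes, trn_lbls):
    # Simplified sketch of proportional class-budget splitting.
    counts = torch.bincount(trn_lbls, minlength=num_classes).float()
    proportions = counts / counts.sum()

    # Give every class at least 1, then award the remainder proportionally.
    budgets = [1] * num_classes
    remainder = budget - num_classes
    if remainder < 0:
        # Full budget smaller than the number of classes: zero out randomly
        # chosen class budgets until the sum fits the budget constraint.
        for i in torch.randperm(num_classes)[:num_classes - budget]:
            budgets[i] = 0
        return budgets

    extra = (proportions * remainder).floor().long()
    for i in range(num_classes):
        budgets[i] += int(extra[i])

    # Hand out any leftover units (lost to flooring) to the largest classes.
    leftover = budget - sum(budgets)
    for i in torch.argsort(proportions, descending=True)[:leftover]:
        budgets[i] += 1
    return budgets
```

With a balanced label tensor and a budget of 10 over 2 classes, this yields a 5/5 split; with a budget smaller than the class count, some classes receive 0.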
33 changes: 23 additions & 10 deletions distil/utils/config_helper.py
@@ -2,15 +2,28 @@
import os

def read_config_file(filename):

print(filename.split('.')[-1])
if filename.split('.')[-1] not in ['json']:
raise IOError('Only json type are supported now!')

if not os.path.exists(filename):
raise FileNotFoundError('Config file does not exist!')

with open(filename, 'r') as f:
config = json.load(f)
"""
Loads and returns a configuration from the supplied filename / path.
Parameters
----------
filename: string
The name/path of the config file to load.
Returns
----------
config: object
The resulting configuration laoded from the JSON file
"""

print(filename.split('.')[-1])
if filename.split('.')[-1] not in ['json']:
raise IOError('Only JSON files are supported for now!')

if not os.path.exists(filename):
raise FileNotFoundError('Config file does not exist!')

with open(filename, 'r') as f:
config = json.load(f)

return config
return config
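A standalone usage sketch of this helper (the function body below mirrors the one in `distil/utils/config_helper.py`; the config keys are made up for illustration):

```python
import json
import os
import tempfile

def read_config_file(filename):
    # Standalone copy of the helper for illustration.
    if filename.split('.')[-1] not in ['json']:
        raise IOError('Only JSON files are supported for now!')
    if not os.path.exists(filename):
        raise FileNotFoundError('Config file does not exist!')
    with open(filename, 'r') as f:
        return json.load(f)

# Write a throwaway JSON config and load it back.
path = os.path.join(tempfile.mkdtemp(), "config.json")
with open(path, "w") as f:
    json.dump({"strategy": "entropy_sampling", "budget": 100}, f)

config = read_config_file(path)
assert config["budget"] == 100
```

Note that the extension check runs before the existence check, so a non-JSON path raises `IOError` even if the file does not exist.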
14 changes: 14 additions & 0 deletions distil/utils/data_handler.py
@@ -55,6 +55,8 @@ class DataHandler_SVHN(Dataset):
Labels to be loaded (default: None)
select: bool
True if loading data without labels, False otherwise
use_test_transform: bool
True if the data handler should apply the test transform. Otherwise, the data handler will use the training transform (default: False)
"""

def __init__(self, X, Y=None, select=True, use_test_transform=False):
@@ -107,6 +109,8 @@ class DataHandler_MNIST(Dataset):
Labels to be loaded (default: None)
select: bool
True if loading data without labels, False otherwise
use_test_transform: bool
True if the data handler should apply the test transform. Otherwise, the data handler will use the training transform (default: False)
"""

def __init__(self, X, Y=None, select=True, use_test_transform=False):
@@ -159,6 +163,8 @@ class DataHandler_KMNIST(Dataset):
Labels to be loaded (default: None)
select: bool
True if loading data without labels, False otherwise
use_test_transform: bool
True if the data handler should apply the test transform. Otherwise, the data handler will use the training transform (default: False)
"""

def __init__(self, X, Y=None, select=True, use_test_transform=False):
@@ -211,6 +217,8 @@ class DataHandler_FASHION_MNIST(Dataset):
Labels to be loaded (default: None)
select: bool
True if loading data without labels, False otherwise
use_test_transform: bool
True if the data handler should apply the test transform. Otherwise, the data handler will use the training transform (default: False)
"""

def __init__(self, X, Y=None, select=True, use_test_transform=False):
@@ -263,6 +271,8 @@ class DataHandler_CIFAR10(Dataset):
Labels to be loaded (default: None)
select: bool
True if loading data without labels, False otherwise
use_test_transform: bool
True if the data handler should apply the test transform. Otherwise, the data handler will use the training transform (default: False)
"""

def __init__(self, X, Y=None, select=True, use_test_transform=False):
@@ -315,6 +325,8 @@ class DataHandler_CIFAR100(Dataset):
Labels to be loaded (default: None)
select: bool
True if loading data without labels, False otherwise
use_test_transform: bool
True if the data handler should apply the test transform. Otherwise, the data handler will use the training transform (default: False)
"""

def __init__(self, X, Y=None, select=True, use_test_transform=False):
@@ -367,6 +379,8 @@ class DataHandler_STL10(Dataset):
Labels to be loaded (default: None)
select: bool
True if loading data without labels, False otherwise
use_test_transform: bool
True if the data handler should apply the test transform. Otherwise, the data handler will use the training transform (default: False)
"""

def __init__(self, X, Y=None, select=True, use_test_transform=False):
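The `use_test_transform` pattern documented across these handlers can be sketched with a toy `Dataset` (illustrative only; the real handlers in `distil/utils/data_handler.py` use per-dataset torchvision transforms rather than the hypothetical lambdas below):

```python
import torch
from torch.utils.data import Dataset

class ToyDataHandler(Dataset):
    # Illustrative sketch of the use_test_transform / select pattern.
    def __init__(self, X, Y=None, select=True, use_test_transform=False):
        # Train transform augments (here: adds noise); test transform does not.
        train_tf = lambda x: x + 0.01 * torch.randn_like(x)
        test_tf = lambda x: x
        self.transform = test_tf if use_test_transform else train_tf
        self.X, self.Y, self.select = X, Y, select

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = self.transform(self.X[idx])
        if self.select:           # unlabeled pool: return features only
            return x
        return x, self.Y[idx]     # labeled data: return (features, label)
```

With `use_test_transform=True` the sample comes back unaugmented, which is what evaluation code expects; the training loop constructs the handler with the default `False`.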
26 changes: 23 additions & 3 deletions docs/source/index.rst
@@ -1,20 +1,40 @@

Welcome to DISTIL's documentation!
==================================
DISTIL:: Deep dIverSified inTeractIve Learning is an efficient and scalable library active learning built on top of pytorch.
DISTIL:: Deep dIverSified inTeractIve Learning is an efficient and scalable active learning library built on top of PyTorch.

**What is DISTIL?**

Distil is a toolkit in PyTorch which provides access to different active learning algorithms. Active Learning (AL) helps in reducing labeling cost and also reduces training time and resources. AL helps in selecting only the required data and experiments show that using only 30% of data for training can reach accuracy levels close to the levels reached when using the entire dataset.
DISTIL is a toolkit in PyTorch which provides access to different active learning algorithms. Active learning (AL) helps in reducing labeling cost and also reduces training time and resources. AL helps in selecting only the required data, and experiments show that using only 30% of the data for training can reach accuracy levels close to the levels reached when using the entire dataset.

**Principles of DISTIL**:

#. Minimal changes to add it to the existing training structure.
#. Independent of the training strategy used.
#. Achieving similar test accuracy with less training data.
#. Huge reduction in labelling cost and time.
#. Huge reduction in labeling cost and time.
#. Access to various active learning strategies with just one line of code.
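To make the selection step concrete, here is a generic uncertainty-sampling sketch in plain PyTorch. This is not DISTIL's API; the function name and signature are made up for illustration:

```python
import torch

def entropy_sampling(model, unlabeled_x, budget):
    # Generic entropy-based uncertainty sampling sketch (not DISTIL's API).
    # Selects the `budget` unlabeled points the model is least certain about.
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(unlabeled_x), dim=1)
    entropy = -(probs * torch.log(probs + 1e-12)).sum(dim=1)
    return torch.topk(entropy, budget).indices  # indices to send for labeling
```

The selected indices are then labeled by an oracle and moved from the unlabeled pool into the training set, after which the model is retrained.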

**Preliminary Results: CIFAR-10**

.. image:: ../../experiment_plots/cifar10_plot_50k.png
:width: 1000px

**Preliminary Results: MNIST**

.. image:: ../../experiment_plots/mnist_plot.png
:width: 1000px

**Preliminary Results: Fashion MNIST**

.. image:: ../../experiment_plots/fmnist_plot.png
:width: 1000px

**Preliminary Results: SVHN**

.. image:: ../../experiment_plots/svhn_plot.png
:width: 1000px

.. toctree::
:maxdepth: 2
:caption: Contents:
