In [1]:

import torch
import torch.nn as nn

In [2]:
class Wrapper(nn.Module):
    '''
    Object wrapper class.
    This a wrapper for objects. It is initialiesed with the object to wrap
    and then proxies the unhandled getattribute methods to it.
    Other classes are to inherit from it.
    '''
    def __init__(self, obj):
        '''
        Wrapper constructor.
        @param obj: object to wrap
        '''
        # wrap the object
        super(Wrapper, self).__init__()
        self.__class__ = type(obj)
        self._wrapped_obj = obj

    def __getattr__(self, attr):
        # see if this object has attr
        # NOTE do not use hasattr, it goes into
        # infinite recurrsion
        if attr in self.__dict__:
            # this object has it
            return getattr(self, attr)
        # proxy to the wrapped object
        return getattr(self._wrapped_obj, attr)
    
    def __call__(self, *args, **kwargs):
        return self._wrapped_obj.__call__(*args, **kwargs)
    
    

In [4]:
import copy
import logging
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (Any, Dict, List, NamedTuple, Optional, Tuple, Type,
                    TypeVar, Union)

import torch
from simple_parsing import MutableField as mutable_field
from simple_parsing import choice, field
from torch import Tensor, nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.utils import save_image

from common.layers import ConvBlock, Flatten
from common.losses import LossInfo
from common.metrics import accuracy, get_metrics
from config import Config
from tasks import AuxiliaryTask, AuxiliaryTaskOptions, Tasks
from utils.utils import fix_channels

logger = logging.getLogger(__file__)

class Classifier(nn.Module):
    @dataclass
    class HParams:
        """ Set of hyperparameters for the classifier.

        We use [simple_parsing](www.github.com/lebrice/simpleparsing) to
        generate command-line arguments for each attribute of this class.
        """
        batch_size: int = 128   # Input batch size for training.
        epochs: int = 10        # Number of epochs to train.
        learning_rate: float = field(default=1e-3, alias="-lr")  # learning rate.

        # Dimensions of the hidden state (feature extractor/encoder output).
        hidden_size: int = 100

        # Prevent gradients of the classifier from backpropagating into the encoder.
        detach_classifier: bool = False

        # Use an encoder architecture from the torchvision.models package.
        encoder_model: Optional[str] = choice({
            "vgg16": models.vgg16,  # This is the only one tested so far.
            "resnet18": models.resnet18,
            "resnet34": models.resnet34,
            "resnet50": models.resnet50,
            "resnet101": models.resnet101,
            "resnet152": models.resnet152,
            "alexnet": models.alexnet,
            # "squeezenet": models.squeezenet1_0,  # Not supported yet (weird output shape)
            "densenet": models.densenet161,
            # "inception": models.inception_v3,  # Not supported yet (creating model takes forever?)
            # "googlenet": models.googlenet,  # Not supported yet (creating model takes forever?)
            "shufflenet": models.shufflenet_v2_x1_0,
            "mobilenet": models.mobilenet_v2,
            "resnext50_32x4d": models.resnext50_32x4d,
            "wide_resnet50_2": models.wide_resnet50_2,
            "mnasnet": models.mnasnet1_0,
        }, default=None)
        # Use the pretrained weights of the ImageNet model from torchvision.
        pretrained_model: bool = False
        # Freeze the weights of the pretrained encoder (except the last layer,
        # which projects from their hidden size to ours).
        freeze_pretrained_model: bool = False


        aux_tasks: AuxiliaryTaskOptions = field(default_factory=AuxiliaryTaskOptions)

    def __init__(self,
                 input_shape: Tuple[int, ...],
                 num_classes: int,
                 encoder: nn.Module,
                 classifier: nn.Module,
                #  auxiliary_task_options: AuxiliaryTaskOptions,
                 hparams: HParams,
                 config: Config):
        super().__init__()
        self.input_shape = input_shape
        self.num_classes = num_classes
        # Feature extractor
        self.encoder = encoder
        # Classifier output layer
        self.classifier = classifier
        self.hparams: Classifier.HParams = hparams
        self.config = config

        self.hidden_size = hparams.hidden_size  
        self.classification_loss = nn.CrossEntropyLoss()
        self.device = self.config.device

        # Share the relevant parameters with all the auxiliary tasks.
        # We do this by setting class attributes.
        AuxiliaryTask.hidden_size   = self.hparams.hidden_size
        AuxiliaryTask.input_shape   = self.input_shape
        AuxiliaryTask.encoder       = self.encoder
        AuxiliaryTask.classifier    = self.classifier
        AuxiliaryTask.preprocessing = self.preprocess_inputs
        
        # Dictionary of auxiliary tasks.
        self.tasks: Dict[str, AuxiliaryTask] = self.hparams.aux_tasks.create_tasks(  # type: ignore
            input_shape=input_shape,
            hidden_size=self.hparams.hidden_size
        )

        # Current task label. (Optional, as we shouldn't rely on this.)
        # TODO: Replace the classifier model with something like CN-DPM or CURL,
        # so we can actually do task-free CL.
        self._current_task_id: Optional[str] = None
        # Dictionary of classifiers to use if we are provided the task-label.
        self.task_classifiers: Dict[str, nn.Module] = nn.ModuleDict()  #type: ignore  

        if self.config.debug and self.config.verbose:
            logger.debug(self)
            logger.debug("Auxiliary tasks:")
            for task_name, task in self.tasks.items():
                logger.debug(f"{task.name}: {task.coefficient}")

        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)  

    def supervised_loss(self, x: Tensor, y: Tensor, h_x: Tensor=None, y_pred: Tensor=None) -> LossInfo:
        h_x = self.encode(x) if h_x is None else h_x
        y_pred = self.logits(h_x) if y_pred is None else y_pred
        y = y.view(-1)
        loss = self.classification_loss(y_pred, y)
        metrics = get_metrics(x=x, h_x=h_x, y_pred=y_pred, y=y)
        loss_info = LossInfo(
            name=Tasks.SUPERVISED,
            total_loss=loss,
            tensors=(dict(x=x, h_x=h_x, y_pred=y_pred, y=y)),
        )
        loss_info.metrics[Tasks.SUPERVISED] = metrics
        return loss_info

    def get_loss(self, x: Tensor, y: Tensor=None) -> LossInfo:
        total_loss = LossInfo("Train" if self.training else "Test")
        h_x = self.encode(x)
        y_pred = self.logits(h_x)
        
        total_loss.total_loss = torch.zeros(1, device=self.device)
        total_loss.tensors["x"] = x.detach()
        total_loss.tensors["h_x"] = h_x.detach()
        total_loss.tensors["y_pred"] = y_pred.detach()

        if y is not None:
            supervised_loss = self.supervised_loss(x=x, y=y, h_x=h_x, y_pred=y_pred)
            total_loss += supervised_loss

        for task_name, aux_task in self.tasks.items():
            if aux_task.enabled:
                aux_task_loss = aux_task.get_scaled_loss(x, h_x=h_x, y_pred=y_pred, y=y)
                total_loss += aux_task_loss
        
        if self.config.debug and self.config.verbose:
            for name, loss in total_loss.losses.items():
                logger.debug(name, loss.total_loss, loss.metrics)

        return total_loss

    def encode(self, x: Tensor):
        x = self.preprocess_inputs(x)
        return self.encoder(x)

    def preprocess_inputs(self, x: Tensor) -> Tensor:
        """Preprocess the input tensor x before it is passed to the encoder.
        
        By default this does nothing. When subclassing the Classifier or 
        switching datasets, you might want to change this behaviour.

        Parameters
        ----------
        - x : Tensor
        
            a batch of inputs.
        
        Returns
        -------
        Tensor
            The preprocessed inputs.
        """
        return fix_channels(x)

    @property
    def current_task_id(self) -> Optional[str]:
        if self._current_task_id is None:
            return None
        return self._current_task_id

    @current_task_id.setter
    def current_task_id(self, value: Optional[Union[int, str]]):
        value = str(value) if value is not None else None
        self._current_task_id = value
        # If there isn't a classifier for this task
        if value and value not in self.task_classifiers.keys():
            if self.config.debug:
                logger.info(f"Creating a new classifier for taskid {value}.")
            # Create one starting from the "global" classifier.
            classifier = copy.deepcopy(self.classifier)
            self.task_classifiers[value] = classifier
            self.optimizer.add_param_group({"params": classifier.parameters()})

    def logits(self, h_x: Tensor) -> Tensor:
        if self.hparams.detach_classifier:
            h_x = h_x.detach()

        # Use the "general" classifier by default.
        classifier = self.classifier
        # If a task-id is given, use the task-specific classifier.
        if self.current_task_id is not None:
            classifier = self.task_classifiers[self.current_task_id]
        return classifier(h_x)
    
    def load_state_dict(self, state_dict: Dict) -> Tuple[List[str], List[str]]:
        current_task_id = self.current_task_id
        for key in state_dict:
            if key.startswith("task_classifiers"):
                n = key.split(".")[1]
                self.current_task_id = n
        return super().load_state_dict(state_dict)


Couldn't import the modules from the falr submodule: cannot import name 'HParams' from 'config' (/Users/oleksostapenko/Projects/SSCL/config.py)
Make sure to run `git submodule init; git submodule update`


NameError: name 'exit' is not defined

In [48]:
c = C('Geeks')

In [49]:
cc = Wrapper(c)

In [51]:
cc("gg")

gg


In [54]:
cc.name = 'Other'

In [55]:
cc.name

'Other'

In [12]:
c.print_

RecursionError: maximum recursion depth exceeded