Skip to content

Commit

Permalink
add annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguanglei committed Nov 13, 2018
1 parent c3359a2 commit e4f1888
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion jdit/assessment/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def forward(self, inp):

# ______________________________________________________________

def compute_act_statistics_from_loader(dataloader, model, gpu_ids):
def compute_act_statistics_from_loader(dataloader:DataLoader, model, gpu_ids):
"""
:param dataloader:
Expand Down
8 changes: 4 additions & 4 deletions jdit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import save, load
from typing import Union
from collections import OrderedDict

from types import FunctionType
class _cached_property(object):
"""
Decorator that converts a method with a single self argument into a
Expand Down Expand Up @@ -89,7 +89,7 @@ class Model(object):
"""

def __init__(self, proto_model: Module = Module, gpu_ids_abs: Union[list,tuple]=(), init_method: [str, function] = "kaiming",
def __init__(self, proto_model: Module = Module, gpu_ids_abs: Union[list,tuple]=(), init_method: [str, FunctionType] = "kaiming",
show_structure=False, verbose=True):
if not gpu_ids_abs:
gpu_ids_abs = []
Expand All @@ -109,7 +109,7 @@ def __call__(self, *args, **kwargs):
def __getattr__(self, item):
return getattr(self.model, item)

def define(self, proto_model: Module, gpu_ids: Union[list, tuple], init_method: Union[str, function],
def define(self, proto_model: Module, gpu_ids: Union[list, tuple], init_method: Union[str, FunctionType],
show_structure: bool):
"""Define and wrap a pytorch module, according to CPU, GPU and multi-GPUs.
Expand Down Expand Up @@ -298,7 +298,7 @@ def count_params(self, proto_model: Module):
# else:
# return self.count_params(self.model)

def _apply_weight_init(self, init_method: Union[str, function], proto_model: Module):
def _apply_weight_init(self, init_method: Union[str, FunctionType], proto_model: Module):
init_name = "No"
if init_method:
if init_method == 'kaiming':
Expand Down
2 changes: 1 addition & 1 deletion jdit/parallel/parallel_trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# coding=utf-8
from abc import abstractmethod
from multiprocessing import Pool

from types import FunctionType

class SupParallelTrainer(object):
""" Training parallel.
Expand Down
3 changes: 2 additions & 1 deletion jdit/trainer/super.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pandas as pd
import numpy as np
from typing import Union
from types import FunctionType

from jdit.optimizer import Optimizer
from jdit.model import Model
Expand Down Expand Up @@ -106,7 +107,7 @@ def get_data_from_batch(self, batch_data: list, device: torch.device):
input, ground_truth = batch_data[0], batch_data[1]
return input.to(device), ground_truth.to(device)

def train_iteration(self, opt: Optimizer, compute_loss_fc: function, tag: str = "Train"):
def train_iteration(self, opt: Optimizer, compute_loss_fc: FunctionType, tag: str = "Train"):
opt.zero_grad()
loss, var_dic = compute_loss_fc()
loss.backward()
Expand Down

0 comments on commit e4f1888

Please sign in to comment.