Skip to content
Permalink
Browse files

fix code issues.

  • Loading branch information
lijunzh committed Apr 22, 2019
1 parent edcb6fa commit 4bb3898b1d4f9efbcdc095b9a164f1b2fbd69620
@@ -1,12 +1,11 @@
import numpy as np

from torch.nn import CrossEntropyLoss
from torch.utils.data import random_split, DataLoader
from torch.utils.data import DataLoader
from torch.utils.data import random_split

import yews.transforms as transforms
from yews.datasets import DatasetArray
from yews.train import Trainer
from yews.models import Cpic
from yews.train import Trainer


if __name__ == '__main__':
@@ -1,4 +1,4 @@
from setuptools import setup, find_packages
from setuptools import setup

requirements = [
'numpy',
@@ -9,6 +9,11 @@ def is_dataset(obj):
"""
return getattr(obj, '__getitem__', None) and getattr(obj, '__len__', None)

def _format_transform_repr(transform, head):
lines = transform.__repr__().splitlines()
return (["{}{}".format(head, lines[0])] +
["{}{}".format(" " * len(head), line) for line in lines[1:]])


class BaseDataset(data.Dataset):
"""An abstract class representing a Dataset.
@@ -93,19 +98,14 @@ def __repr__(self):
body.append("Root location: {}".format(self.root))
body += self.extra_repr().splitlines()
if self.sample_transform is not None:
body += self._format_transform_repr(self.sample_transform,
"Sample transforms: ")
body += _format_transform_repr(self.sample_transform,
"Sample transforms: ")
if self.target_transform is not None:
body += self._format_transform_repr(self.target_transform,
"Target transforms: ")
body += _format_transform_repr(self.target_transform,
"Target transforms: ")
lines = [head] + [" " * self._repr_indent + line for line in body]
return '\n'.join(lines)

def _format_transform_repr(self, transform, head):
lines = transform.__repr__().splitlines()
return (["{}{}".format(head, lines[0])] +
["{}{}".format(" " * len(head), line) for line in lines[1:]])

def extra_repr(self):
return ""

@@ -140,5 +140,5 @@ def is_valid(self):

return self.root.exists()

def handle_invalid(self, **kwargs):
def handle_invalid(self):
raise ValueError(f"{self.root} is not a valid path.")
@@ -1,5 +1,3 @@
from pathlib import Path

import numpy as np

from .base import PathDataset
@@ -20,14 +20,14 @@ class Wenchuan(DatasetArrayFolder):

url = URL('https://www.dropbox.com/s/enr75zt2ukx118r/wenchuan.tar.bz2?dl=1')

def __init__(self, path, download=False, **kwargs):
def __init__(self, download=False, **kwargs):
# verify download flag
if type(download) is not bool:
if not isinstance(download, bool):
raise ValueError("`download` needs to be True or False.")

# verify if dataset is ready
try:
super().__init__(path=path, **kwargs)
super().__init__(**kwargs)
except ValueError:
if download:
# download compressed file from source if not exists
@@ -38,7 +38,7 @@ def __init__(self, path, download=False, **kwargs):
print("Extracting dataset ...")
extract_bz2(fpath, self.root)
# try initiate DatasetArrayFolder again
super().__init__(path=path, **kwargs)
super().__init__(**kwargs)
else:
raise ValueError(f"{self.root} contains no valid dataset. "
f"Consider set `download=True` and remove broken bz2 file.")
@@ -1,7 +1,5 @@
# TO-DO: need to add model_zoo utility and pretrained models.
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
'Cpic',
@@ -1,5 +1,3 @@
import numpy as np
import torch
from torch import optim

from . import functional as F
@@ -118,4 +116,3 @@ def train(self, train_loader, val_loader, epochs=100, print_freq=None):
is_best = self.val_acc[-1] > self.best_acc
self.best_acc = max(self.val_acc[-1], self.best_acc)
print(f"Training fisihed. Best accuracy is {self.best_acc}")

@@ -1,5 +1,3 @@
import numpy as np

__all__ = [
"MovingAverageMeter",
"ExponentialMovingAverageMeter",
@@ -19,10 +17,10 @@ def __init__(self):
self.avg = None
self.count = 0

def reset(self, *args):
def reset(self):
raise NotImplementedError

def update(self, *args):
def update(self):
raise NotImplementedError


@@ -73,4 +71,3 @@ def update(self, val, n=1):

self.val = val
self.count += n

@@ -1,5 +1,5 @@
from .base import BaseTransform
from . import functional as F
from .base import BaseTransform

__all__ = [
"ToTensor",
@@ -26,11 +26,11 @@ class ToInt(BaseTransform):
"""

def __init__(self, lookup):
if type(lookup) is dict:
if isinstance(lookup, dict):
self.lookup = lookup
else:
raise ValueError("Lookup table needs to be a dictionary.")
if any([type(val) is not int for val in self.lookup.values()]):
if any([not isinstance(val, int) for val in self.lookup.values()]):
raise ValueError("Values of the lookup table need to be Int.")

def __call__(self, label):
@@ -59,4 +59,3 @@ def __init__(self, samplestart, sampleend):

def __call__(self, wav):
return wav[:, self.start:self.end]

0 comments on commit 4bb3898

Please sign in to comment.
You can’t perform that action at this time.