
Commit

load pretrained model
guochengqian committed Oct 26, 2019
1 parent 9ac881a · commit 9123f8c
Showing 6 changed files with 29 additions and 11 deletions.
examples/part_sem_seg/opt.py (6 additions, 0 deletions)
@@ -76,7 +76,13 @@ def __init__(self):

         if args.pretrained_model:
             if args.pretrained_model[0] != '/':
+                if args.pretrained_model[0:2] == 'ex':
+                    args.pretrained_model = os.path.join(os.path.dirname(os.path.dirname(dir_path)),
+                                                         args.pretrained_model)
+                else:
+                    args.pretrained_model = os.path.join(dir_path, args.pretrained_model)
+
                 args.pretrained_model = os.path.join(dir_path, args.pretrained_model)

         if not args.ckpt_path:
             args.save_path = os.path.join(dir_path, 'checkpoints/ckpts' + '-' + args.post + '-' + args.time)
         else:
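The added branch distinguishes paths given relative to the repository root (those starting with `ex`, presumably `examples/...`) from paths relative to the example directory. A minimal standalone sketch of this resolution logic, not part of the commit, with illustrative names only and `dir_path` assumed to be the directory containing `opt.py`:

```
import os

def resolve_pretrained(pretrained_model, dir_path):
    # dir_path is assumed to be the example directory, e.g. <repo_root>/examples/part_sem_seg.
    if pretrained_model and pretrained_model[0] != '/':
        if pretrained_model[0:2] == 'ex':
            # 'examples/...' paths are resolved against the repository root,
            # two directory levels above the example folder.
            repo_root = os.path.dirname(os.path.dirname(dir_path))
            return os.path.join(repo_root, pretrained_model)
        # Other relative paths stay relative to the example directory.
        return os.path.join(dir_path, pretrained_model)
    return pretrained_model

# resolve_pretrained('examples/sem_seg_dense/checkpoints/ckpt_50.pth',
#                    '/path/to/repo/examples/sem_seg_dense')
# -> '/path/to/repo/examples/sem_seg_dense/checkpoints/ckpt_50.pth'
```

Note that joining `dir_path` with an already absolute path afterwards is a no-op, since `os.path.join` discards components preceding an absolute argument.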
examples/ppi/opt.py (5 additions, 0 deletions)
@@ -68,6 +68,11 @@ def __init__(self):

         if args.pretrained_model:
             if args.pretrained_model[0] != '/':
+                if args.pretrained_model[0:2] == 'ex':
+                    args.pretrained_model = os.path.join(os.path.dirname(os.path.dirname(dir_path)),
+                                                         args.pretrained_model)
+                else:
+                    args.pretrained_model = os.path.join(dir_path, args.pretrained_model)
                 args.pretrained_model = os.path.join(dir_path, args.pretrained_model)

         if not args.ckpt_path:
examples/sem_seg_dense/README.md (2 additions, 2 deletions)
@@ -29,14 +29,14 @@ Other parameters for changing the architecture are:
 Quick test on area 5, run:

 ```
-python examples/sem_seg_dense/test.py --pretrained_model sem_seg_dense/checkpoints/densedeepgcn-res-edge-ckpt_50.pth --batch_size 1 --test_path /data/deepgcn/S3DIS --task sem_seg_dense
+python examples/sem_seg_dense/test.py --pretrained_model examples/sem_seg_dense/checkpoints/densedeepgcn-res-edge-ckpt_50.pth --batch_size 32 --test_path /data/deepgcn/S3DIS
 ```

 #### Pretrained Models
 Our pretrained models will be available soon.
 Use the `--pretrained_model` parameter to choose the specific pretrained model you want.
 ```
-python examples/sem_seg_dense/test.py --pretrained_model sem_seg_dense/checkpoints/densedeepgcn-res-edge-ckpt_50.pth --batch_size 1 --test_path /data/deepgcn/S3DIS --task sem_seg_dense
+python examples/sem_seg_dense/test.py --pretrained_model examples/sem_seg_dense/checkpoints/densedeepgcn-res-edge-ckpt_50.pth --batch_size 32 --test_path /data/deepgcn/S3DIS
 ```

 #### Visualization
examples/sem_seg_dense/architecture.py (1 addition, 1 deletion)
@@ -54,5 +54,5 @@ def forward(self, inputs):

         fusion = torch.max_pool2d(self.fusion_block(feats), kernel_size=[feats.shape[2], feats.shape[3]])
         fusion = torch.repeat_interleave(fusion, repeats=feats.shape[2], dim=2)
-        return self.prediction(torch.cat((fusion, feats), dim=1)).squeeze()
+        return self.prediction(torch.cat((fusion, feats), dim=1)).squeeze(-1)

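Switching from `squeeze()` to `squeeze(-1)` removes only the trailing singleton dimension, so the batch axis survives even when a batch contains a single sample. A small illustration with assumed shapes (batch, classes, points, 1), not taken from the repo:

```
import torch

# Assumed prediction shape: (batch, classes, points, 1).
out = torch.zeros(1, 13, 4096, 1)

print(out.squeeze().shape)    # torch.Size([13, 4096])    - batch axis dropped as well
print(out.squeeze(-1).shape)  # torch.Size([1, 13, 4096]) - only the last axis dropped
```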
examples/sem_seg_dense/opt.py (11 additions, 6 deletions)
@@ -14,7 +14,7 @@ class OptInit():
     def __init__(self):
         parser = argparse.ArgumentParser(description='PyTorch implementation of Deep GCN')

-        parser.add_argument('--phase', default='train', type=str, help='train or test(default)')
+        parser.add_argument('--phase', default='test', type=str, help='train or test(default)')
         parser.add_argument('--use_cpu', action='store_true', help='use cpu?')

         # dataset args
@@ -47,14 +47,14 @@ def __init__(self):
         parser.add_argument('--conv', default='edge', type=str, help='graph conv layer {edge, mr}')
         parser.add_argument('--act', default='relu', type=str, help='activation layer {relu, prelu, leakyrelu}')
         parser.add_argument('--norm', default='batch', type=str, help='batch or instance normalization')
-        parser.add_argument('--bias', default=True, type=bool, help='bias of conv layer True or False')
+        parser.add_argument('--bias', default=True, type=bool, help='bias of conv layer True or False')
         parser.add_argument('--n_filters', default=64, type=int, help='number of channels of deep features')
         parser.add_argument('--n_blocks', default=28, type=int, help='number of basic blocks')
         parser.add_argument('--dropout', default=0.3, type=float, help='ratio of dropout')

         # dilated knn
         parser.add_argument('--epsilon', default=0.2, type=float, help='stochastic epsilon for gcn')
-        parser.add_argument('--stochastic', default=True, type=bool, help='stochastic for gcn, True or False')
+        parser.add_argument('--stochastic', default=True, type=bool, help='stochastic for gcn, True or False')
         args = parser.parse_args()

         dir_path = os.path.dirname(os.path.abspath(__file__))
@@ -66,7 +66,13 @@ def __init__(self):

         if args.pretrained_model:
             if args.pretrained_model[0] != '/':
+                if args.pretrained_model[0:2] == 'ex':
+                    args.pretrained_model = os.path.join(os.path.dirname(os.path.dirname(dir_path)),
+                                                         args.pretrained_model)
+                else:
+                    args.pretrained_model = os.path.join(dir_path, args.pretrained_model)
+
                 args.pretrained_model = os.path.join(dir_path, args.pretrained_model)

         args.save_path = os.path.join(dir_path, 'checkpoints/ckpts' + '-' + args.post)
         args.logdir = os.path.join(dir_path, 'logs/' + args.post)

@@ -110,7 +116,7 @@ def logging_init(self):
                                              'formatter': 'debug',
                                              'level': logging.INFO},
                                  'file': {'class': 'logging.FileHandler',
-                                          'filename': os.path.join(self.args.logdir, self.args.post+'.log'),
+                                          'filename': os.path.join(self.args.logdir, self.args.post + '.log'),
                                           'formatter': 'debug',
                                           'level': logging.INFO}},
                     'root': {'handlers': ('console', 'file'), 'level': 'INFO'}
@@ -123,7 +129,7 @@ def make_dir(self):
             shutil.rmtree(self.args.logdir)
         os.makedirs(self.args.logdir)

-        if not os.path.exists(self.args.save_path):
+        if not os.path.exists(self.args.save_path):
             os.makedirs(self.args.save_path)
         if not os.path.exists(self.args.train_path):
             os.makedirs(self.args.train_path)
@@ -136,4 +142,3 @@ def set_seed(self, seed=0):
         torch.cuda.manual_seed_all(seed)
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.benchmark = False
-
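One thing worth noting about the `--bias` and `--stochastic` options retained above: with argparse, `type=bool` applies `bool()` to the raw command-line string, and any non-empty string (including `'False'`) is truthy, so these flags cannot be turned off from the command line. A short standard-library illustration, independent of this repo:

```
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--stochastic', default=True, type=bool)

# bool('False') is True, so the option stays enabled either way.
print(parser.parse_args([]).stochastic)                         # True (default)
print(parser.parse_args(['--stochastic', 'False']).stochastic)  # True
```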
examples/sem_seg_dense/test.py (4 additions, 2 deletions)
@@ -50,8 +50,10 @@ def test(model, loader, opt):
         target_np = gt.cpu().numpy()

         for cl in range(opt.n_classes):
-            I = np.sum(np.logical_and(pred_np == cl, target_np == cl))
-            U = np.sum(np.logical_or(pred_np == cl, target_np == cl))
+            cur_gt_mask = (target_np == cl)
+            cur_pred_mask = (pred_np == cl)
+            I = np.sum(np.logical_and(cur_pred_mask, cur_gt_mask), dtype=np.float32)
+            U = np.sum(np.logical_or(cur_pred_mask, cur_gt_mask), dtype=np.float32)
             Is[i, cl] = I
             Us[i, cl] = U

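The rewritten loop computes each class mask once and sums it as `float32`, accumulating per-class intersection and union counts that are later turned into IoU. A toy, self-contained illustration of the same per-class computation (made-up labels, not repo data):

```
import numpy as np

pred_np   = np.array([0, 1, 1, 2, 2, 2])
target_np = np.array([0, 1, 2, 2, 2, 1])

for cl in range(3):
    cur_gt_mask = (target_np == cl)
    cur_pred_mask = (pred_np == cl)
    I = np.sum(np.logical_and(cur_pred_mask, cur_gt_mask), dtype=np.float32)
    U = np.sum(np.logical_or(cur_pred_mask, cur_gt_mask), dtype=np.float32)
    print(cl, I / U)  # per-class IoU: 1.0, 0.333..., 0.5
```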
