Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing different device issue on multi-layer UNet and UNETR #399

Merged
merged 11 commits into from
Apr 18, 2022
11 changes: 6 additions & 5 deletions GANDLF/models/light_unet_multilayer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
"""
Implementation of UNet
Implementation of Light UNet
"""
from torch.nn import ModuleList

from GANDLF.models.seg_modules.DownsamplingModule import DownsamplingModule
from GANDLF.models.seg_modules.EncodingModule import EncodingModule
Expand Down Expand Up @@ -58,10 +59,10 @@ def __init__(
network_kwargs=self.network_kwargs,
)

self.ds = []
self.en = []
self.us = []
self.de = []
self.ds = ModuleList([])
self.en = ModuleList([])
self.us = ModuleList([])
self.de = ModuleList([])

for _ in range(0, self.num_layers):
self.ds.append(
Expand Down
11 changes: 7 additions & 4 deletions GANDLF/models/unet_multilayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""
Implementation of UNet
"""
from torch.nn import ModuleList

from GANDLF.models.seg_modules.DownsamplingModule import DownsamplingModule
from GANDLF.models.seg_modules.EncodingModule import EncodingModule
Expand Down Expand Up @@ -58,10 +59,10 @@ def __init__(
network_kwargs=self.network_kwargs,
)

self.ds = []
self.en = []
self.us = []
self.de = []
self.ds = ModuleList([])
self.en = ModuleList([])
self.us = ModuleList([])
self.de = ModuleList([])

for i_lay in range(0, self.num_layers):
self.ds.append(
Expand Down Expand Up @@ -127,9 +128,11 @@ def forward(self, x):
"""
y = []
y.append(self.ins(x))
print("x.device:", x.device)

# [downsample --> encode] x num layers
for i in range(0, self.num_layers):
print("y[i].device:", y[i].device)
temp = self.ds[i](y[i])
y.append(self.en[i](temp))

Expand Down
7 changes: 4 additions & 3 deletions GANDLF/models/unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from GANDLF.models.seg_modules.out_conv import out_conv
import torch
import torch.nn as nn
from torch.nn import ModuleList
import numpy as np
import math
from GANDLF.utils.generic import checkPatchDimensions
Expand Down Expand Up @@ -264,7 +265,7 @@ def __init__(
self.out_layers = out_layers
self.num_layers = num_layers
self.embed = _Embedding(img_size, patch_size, in_feats, embed_size, Conv)
self.layers = []
self.layers = ModuleList([])

for _ in range(0, num_layers):
layer = _TransformerLayer(
Expand Down Expand Up @@ -369,8 +370,8 @@ def __init__(
Norm=self.Norm,
)

self.upsampling = []
self.convs = []
self.upsampling = ModuleList([])
self.convs = ModuleList([])

for i in range(0, self.depth - 1):
# add deconv blocks
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ def run(self):
"pandas",
"pylint",
"scikit-learn>=0.23.2",
"scikit-image>=0.19.1",
"pickle5>=0.0.11",
"setuptools",
"seaborn",
"pyyaml",
"tiffslide",
"scikit-image",
"matplotlib",
"requests>=2.25.0",
"pyvips",
Expand Down
Loading