
Commit e461ec0

fixing Win failed import (Lightning-AI#1163)
* version
* try fix distrib
* update try import
1 parent 49d000c commit e461ec0


8 files changed: +18, -19 lines changed


.github/workflows/rebase.yml

Lines changed: 2 additions & 2 deletions
@@ -2,8 +2,8 @@ name: Automatic Rebase
 # https://github.com/marketplace/actions/automatic-rebase
 
 on:
-  issue_comment:
-    types: [created]
+- pull_request
+
 jobs:
   rebase:
     name: Rebase

pytorch_lightning/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 """Root package info."""
 
-__version__ = '0.7.1'
+__version__ = '0.7.2-dev'
 __author__ = 'William Falcon et al.'
 __author_email__ = 'waf2107@columbia.edu'
 __license__ = 'Apache-2.0'

pytorch_lightning/core/hooks.py

Lines changed: 2 additions & 2 deletions
@@ -22,10 +22,10 @@
 
 try:
     from apex import amp
-
-    APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
 
 
 class ModelHooks(torch.nn.Module):
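
The hunk above is the heart of the import fix: the availability flag moves from the try body into an else clause, so the except branch guards nothing but the import itself. A minimal, self-contained sketch of the same guard; the helper function at the end is illustrative and not part of the commit.

# Optional-import guard: only the import sits inside try, so except can only
# be triggered by that import failing (e.g. apex missing on a Windows box).
try:
    from apex import amp
except ImportError:
    APEX_AVAILABLE = False
else:
    APEX_AVAILABLE = True  # runs only when the import above succeeded


def describe_amp_support() -> str:
    # Hypothetical helper showing how such a flag is usually consulted later,
    # at call time rather than import time.
    return 'NVIDIA apex is available' if APEX_AVAILABLE else 'running without apex'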

pytorch_lightning/core/lightning.py

Lines changed: 4 additions & 4 deletions
@@ -9,7 +9,7 @@
 
 import torch
 from torch import Tensor
-from torch.distributed import init_process_group
+import torch.distributed as torch_distrib
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim import Adam
 from torch.optim.optimizer import Optimizer
@@ -24,10 +24,10 @@
 
 try:
     import torch_xla.core.xla_model as xm
-    XLA_AVAILABLE = True
-
 except ImportError:
     XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
 
 
 class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
@@ -859,7 +859,7 @@ def init_ddp_connection(self):
 
         root_node = self.trainer.resolve_root_node_address(root_node)
         os.environ['MASTER_ADDR'] = root_node
-        init_process_group('nccl', rank=proc_rank, world_size=world_size)
+        torch_distrib.init_process_group('nccl', rank=proc_rank, world_size=world_size)
 
     def configure_apex(
             self,
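
This file carries the Windows-specific part of the change: on a PyTorch build without distributed support, "from torch.distributed import init_process_group" can fail at import time, while importing the module under an alias succeeds and defers any failure to the moment the function is actually called. A hedged sketch of that idea; the helper name and the is_available() check are added here for illustration and are not in the commit.

import torch.distributed as torch_distrib  # module import works even without distributed support


def init_ddp_connection_sketch(proc_rank: int, world_size: int) -> None:
    # Stand-in for LightningModule.init_ddp_connection: the distributed symbol
    # is only resolved here, at call time, so CPU or Windows runs that never
    # reach this code path are unaffected.
    if not torch_distrib.is_available():
        raise RuntimeError('torch.distributed is not available in this build')
    torch_distrib.init_process_group('nccl', rank=proc_rank, world_size=world_size)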

pytorch_lightning/trainer/auto_mix_precision.py

Lines changed: 3 additions & 4 deletions
@@ -1,13 +1,12 @@
-
+import logging as log
 from abc import ABC
 
 try:
     from apex import amp
-
-    APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
-    import logging as log
+else:
+    APEX_AVAILABLE = True
 
 
 class TrainerAMPMixin(ABC):
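
Besides switching to the else-based guard, this hunk also hoists "import logging as log" to the top of the module; previously it only ran inside the except branch. A rough, illustrative sketch of why the logger has to exist in both cases; the mixin below is simplified and its method name is made up.

import logging as log

try:
    from apex import amp  # optional dependency
except ImportError:
    APEX_AVAILABLE = False
else:
    APEX_AVAILABLE = True


class AMPMixinSketch:
    # Simplified stand-in for TrainerAMPMixin: the warning path needs log
    # regardless of whether the apex import succeeded.
    use_amp = False

    def warn_if_amp_missing(self) -> None:
        if self.use_amp and not APEX_AVAILABLE:
            log.warning('16-bit precision requested but apex is not installed')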

pytorch_lightning/trainer/data_loading.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Union, List, Tuple, Callable
 
-import torch.distributed as dist
+import torch.distributed as torch_distrib
 from torch.utils.data import SequentialSampler, DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
@@ -224,7 +224,7 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
         # get the function we'll use to get data
         if self.use_ddp or self.use_ddp2:
             # all processes wait until data download has happened
-            dist.barrier()
+            torch_distrib.barrier()
 
         # data download/load on TPU
         elif self.use_tpu and XLA_AVAILABLE:
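
The remaining trainer files below make the same rename from dist to torch_distrib; the barrier keeps its role of holding every process until all ranks reach the same point, for example so rank 0 can finish downloading data before other ranks read it. A hedged sketch of that pattern with made-up names, guarded so it is a no-op when no process group is initialized.

import torch.distributed as torch_distrib


def prepare_data_once(proc_rank: int, download_fn) -> None:
    # Illustrative only: rank 0 performs the download, then every rank waits
    # at the barrier so nobody reads the dataset before it exists on disk.
    if proc_rank == 0:
        download_fn()
    if torch_distrib.is_available() and torch_distrib.is_initialized():
        torch_distrib.barrier()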

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 
 import torch
 from torch import optim
-import torch.distributed as dist
+import torch.distributed as torch_distrib
 import torch.multiprocessing as mp
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
@@ -748,7 +748,7 @@ def run_pretrain_routine(self, model: LightningModule):
             self.logger.save()
 
         if self.use_ddp or self.use_ddp2:
-            dist.barrier()
+            torch_distrib.barrier()
 
         # wait for all models to restore weights
         if self.on_tpu and XLA_AVAILABLE:

pytorch_lightning/trainer/training_io.py

Lines changed: 2 additions & 2 deletions
@@ -100,7 +100,7 @@
 from typing import Union
 
 import torch
-import torch.distributed as dist
+import torch.distributed as torch_distrib
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
@@ -177,7 +177,7 @@ def restore_weights(self, model):
         # wait for all models to restore weights
         if self.use_ddp or self.use_ddp2:
             # wait for all processes to catch up
-            dist.barrier()
+            torch_distrib.barrier()
 
         # wait for all models to restore weights
        if self.on_tpu and XLA_AVAILABLE:
