
4 channel test- Training dataset has fewer elements than batch size #1114

Closed
Tobias1234 opened this issue Feb 18, 2021 · 14 comments · Fixed by #1116

@Tobias1234

Tobias1234 commented Feb 18, 2021

Hi!
I am trying to train on 4 channels (R,G,B, elevation). I am using the master branch in a Docker image with local data.

After many attempts, I keep getting the same error when the run reaches the train command: 'Training dataset has fewer elements than batch size.'
I tried setting the batch size to 1 and increasing the number of epochs. I also tried both training and validating on image 2 instead of image 3.
But I get the same error every time.

I can't figure out whether it's something in my code or in the data that I have to change.

Message:

File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py", line 541, in setup_data
'Training dataset has fewer elements than batch size.')

My data:

https://drive.google.com/drive/folders/1ed0NpcjWOdkiSEuliszkDmytuLqVrdO5?usp=sharing

[Attached images: Image 1, Image 2, Image 3]

import os
from os.path import join, basename

import albumentations as A  # used below as A.Normalize / A.to_dict

from rastervision.core.rv_pipeline import *
from rastervision.core.backend import *
from rastervision.core.data import *
from rastervision.core.analyzer import *
from rastervision.pytorch_backend import *
from rastervision.pytorch_learner import *
from rastervision.pytorch_backend.examples.utils import (get_scene_info,
                                                         save_image_crop)
from rastervision.pytorch_backend.examples.semantic_segmentation.utils import (
    example_multiband_transform, example_rgb_transform, imagenet_stats,
    Unnormalize)

def get_config(runner,
               multiband: bool = True,
               external_model: bool = False,
               augment: bool = False,
               nochip: bool = False,
               test: bool = False):
    root_uri = '/opt/data/output/'
    train_image_uris = ['/opt/data/data_input/images/1.tif','/opt/data/data_input/images/2.tif']
    train_label_uris = ['/opt/data/data_input/labels/1.geojson','/opt/data/data_input/labels/2.geojson']
    train_scene_ids = ['1','2']
    train_scene_list = list(zip(train_scene_ids, train_image_uris, train_label_uris))

    val_image_uri = '/opt/data/data_input/images/3.tif'
    val_label_uri = '/opt/data/data_input/labels/3.geojson'
    val_scene_id = '3'
  

    train_scenes_input = []

    if multiband:
        # use all 4 channels
        channel_order = [0, 1, 2, 3]
        channel_display_groups = {'RGB': (0, 1, 2), 'elev': (3, )}
        aug_transform = example_multiband_transform
    else:
        # use elev, red, & green channels only
        channel_order = [3, 0, 1]
        channel_display_groups = None
        aug_transform = example_rgb_transform

    if augment:
        mu, std = imagenet_stats['mean'], imagenet_stats['std']
        mu, std = mu[channel_order], std[channel_order]

        base_transform = A.Normalize(mean=mu.tolist(), std=std.tolist())
        plot_transform = Unnormalize(mean=mu, std=std)

        aug_transform = A.to_dict(aug_transform)
        base_transform = A.to_dict(base_transform)
        plot_transform = A.to_dict(plot_transform)
    else:
        aug_transform = None
        base_transform = None
        plot_transform = None

    chip_sz = 300
    img_sz = chip_sz
    if nochip:
        chip_options = SemanticSegmentationChipOptions()
    else:
        chip_options = SemanticSegmentationChipOptions(
            window_method=SemanticSegmentationWindowMethod.sliding,
            stride=chip_sz)

    class_config = ClassConfig(
        names=['building', 'background'], colors=['red', 'black'])

    def make_scene(scene_id, image_uri, label_uri):
     
        raster_source = RasterioSourceConfig(
            uris=[image_uri],
            channel_order=channel_order,
            transformers=[StatsTransformerConfig()])
        vector_source = GeoJSONVectorSourceConfig(
            uri=label_uri, default_class_id=0, ignore_crs_field=True)
        label_source = SemanticSegmentationLabelSourceConfig(
            raster_source=RasterizedSourceConfig(
                vector_source=vector_source,
                rasterizer_config=RasterizerConfig(background_class_id=1)))
        return SceneConfig(
            id=scene_id,
            raster_source=raster_source,
            label_source=label_source)


    for scene in train_scene_list:
        train_scenes_input.append(make_scene(*scene))
        
    dataset = DatasetConfig(
        class_config=class_config,
        train_scenes=train_scenes_input,
        validation_scenes=[
            make_scene(val_scene_id, val_image_uri, val_label_uri)
        ])

    # Use the PyTorch backend for the SemanticSegmentation pipeline.
    chip_sz = 300
    backend = PyTorchSemanticSegmentationConfig(
        data=SemanticSegmentationImageDataConfig(),
        model=SemanticSegmentationModelConfig(backbone=Backbone.resnet50),
        solver=SolverConfig(lr=1e-4, num_epochs=10, batch_sz=1, one_cycle=True))
    chip_options = SemanticSegmentationChipOptions(
        window_method=SemanticSegmentationWindowMethod.random_sample,
        chips_per_scene=10)

    return SemanticSegmentationConfig(
        root_uri=root_uri,
        dataset=dataset,
        backend=backend,
        train_chip_sz=chip_sz,
        predict_chip_sz=chip_sz)
@AdeelH
Collaborator

AdeelH commented Feb 18, 2021

Please share the full command that you're using to run this. Also, can you look at the zip files in /opt/data/output/chip and see how many training chips there are?

Another thing I would suggest is to change this

chip_options = SemanticSegmentationChipOptions(
    window_method=SemanticSegmentationWindowMethod.random_sample,
    chips_per_scene=10)

to

chip_options = SemanticSegmentationChipOptions(window_method=SemanticSegmentationWindowMethod.sliding)

and see if that helps.
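To answer the chip-count question without unzipping anything by hand, a small stdlib sketch like the one below can tally what is inside the chip archives. It assumes the chip stage leaves one or more .zip files somewhere under /opt/data/output/chip whose entries look like train/img/... and valid/labels/...; count_chips_by_extension is a hypothetical helper, not part of Raster Vision.

```python
import zipfile
from collections import Counter
from pathlib import Path


def count_chips_by_extension(chip_dir):
    """Count files inside every .zip under chip_dir, keyed by (split, extension).

    Hypothetical helper. Assumes archive entries look like
    'train/img/<name>.npy', so the first path component is the split.
    """
    counts = Counter()
    for zip_path in Path(chip_dir).rglob("*.zip"):
        with zipfile.ZipFile(zip_path) as zf:
            for name in zf.namelist():
                if name.endswith("/"):
                    continue  # skip directory entries
                parts = name.split("/")
                split = parts[0] if len(parts) > 1 else ""
                counts[(split, Path(name).suffix)] += 1
    return counts


if __name__ == "__main__":
    print(count_chips_by_extension("/opt/data/output/chip"))
```

A result dominated by .npy images would already hint at the channel-count issue discussed further down in this thread.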

@Tobias1234
Author

Tobias1234 commented Feb 18, 2021

I am running

docker run --ipc=host --rm -it --name devtest27 --mount type=bind,source="C:/Users/tobbe/RV2/RV_CODE_DIR",target=/opt/src/code --mount type=bind,source="C:/Users/tobbe/RV2/RV_OUT_DIR",target=/opt/data/output --mount type=bind,source="C:/Users/tobbe/RV2/RV_DATA_INPUT_DIR",target=/opt/data/data_input quay.io/azavea/raster-vision:pytorch-latest /bin/bash

and then, inside the container, rastervision run local code/full_train15.py

I changed it to chip_options = SemanticSegmentationChipOptions(window_method=SemanticSegmentationWindowMethod.sliding), but I still get the same error.

Looking in /opt/data/output/chip:
/train/img/ contains 392 .npy objects
/train/labels/ contains 392 objects

/valid/img/ contains 50 .npy objects
/valid/labels/ contains 50 objects

End of run before crash:

    self.setup_data()
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py", line 541, in setup_data
    'Training dataset has fewer elements than batch size.')
rastervision.pipeline.config.ConfigError: Training dataset has fewer elements than batch size.
/opt/data/output/Makefile:12: recipe for target '2' failed
make: *** [2] Error 1

@AdeelH
Collaborator

AdeelH commented Feb 18, 2021

Can you share the output log for the train command? That is, everything after "Running train command...".

@Tobias1234
Author

Thank you, Adeel, for the super-fast feedback! :-)

Running train command...
Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /opt/data/torch-cache/hub/checkpoints/resnet50-19c8e357.pth
100%|█████████████████████████████████████████████████████████████████████████████████| 97.8M/97.8M [00:27<00:00, 3.78MB/s]
2021-02-18 10:43:22:rastervision.pytorch_learner.learner: INFO - model=SemanticSegmentationModelConfig(backbone=<Backbone.resnet50: 'resnet50'>, pretrained=True, init_weights=None, external_def=None, type_hint='semantic_segmentation_model') solver=SolverConfig(lr=0.0001, num_epochs=10, test_num_epochs=2, test_batch_sz=4, overfit_num_steps=1, sync_interval=1, batch_sz=1, one_cycle=True, multi_stage=[], class_loss_weights=None, ignore_last_class=False, external_loss_def=None, type_hint='solver') data=SemanticSegmentationImageDataConfig(class_names=['building', 'background', 'null'], class_colors=['red', 'black', 'black'], img_sz=256, train_sz=None, num_workers=4, augmentors=['RandomRotate90', 'HorizontalFlip', 'VerticalFlip'], base_transform=None, aug_transform=None, plot_options=PlotOptions(transform=None, type_hint='plot_options'), preview_batch_limit=None, type_hint='semantic_segmentation_image_data', data_format=<SemanticSegmentationDataFormat.default: 'default'>, uri='/opt/data/output/chip', group_uris=None, group_train_sz=None, group_train_sz_rel=None, img_channels=3, channel_display_groups={'Input': (0, 1, 2)}, img_format='png', label_format='png') predict_mode=False test_mode=False overfit_mode=False eval_train=False save_model_bundle=True log_tensorboard=True run_tensorboard=False output_uri='/opt/data/output/train' type_hint='semantic_segmentation_learner'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/opt/src/rastervision_pipeline/rastervision/pipeline/cli.py", line 248, in <module>
    main()
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 722, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 697, in main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 1066, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 895, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 535, in invoke
    return callback(*args, **kwargs)
  File "/opt/src/rastervision_pipeline/rastervision/pipeline/cli.py", line 240, in run_command
    runner=runner)
  File "/opt/src/rastervision_pipeline/rastervision/pipeline/cli.py", line 217, in _run_command
    command_fn()
  File "/opt/src/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline.py", line 115, in train
    backend.train(source_bundle_uri=self.config.source_bundle_uri)
  File "/opt/src/rastervision_pytorch_backend/rastervision/pytorch_backend/pytorch_learner_backend.py", line 74, in train
    learner = self.learner_cfg.build(self.tmp_dir, training=True)
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner_config.py", line 157, in build
    training=training)
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py", line 128, in __init__
    self.setup_training(loss_def_path=loss_def_path)
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py", line 182, in setup_training
    self.setup_data()
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py", line 541, in setup_data
    'Training dataset has fewer elements than batch size.')
rastervision.pipeline.config.ConfigError: Training dataset has fewer elements than batch size.
/opt/data/output/Makefile:12: recipe for target '2' failed
make: *** [2] Error 1
root@83bb64ce803f:/opt/src#

@AdeelH
Collaborator

AdeelH commented Feb 18, 2021

The error seems to be occurring because the train stage is incorrectly assuming that there are 3 channels only and therefore looking for .png files instead of .npy files. We can specify the correct number of channels explicitly by changing

data=SemanticSegmentationImageDataConfig(),

to

data=SemanticSegmentationImageDataConfig(img_channels=len(channel_order)),

I think this should fix the problem.
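The failure mode described above can be reproduced in miniature. The two logged configs show img_channels=3 paired with img_format='png' and img_channels=4 paired with img_format='npy'; the sketch below hard-codes only that observed pairing (inferred_img_format, count_matching_chips and check_batch_size are hypothetical helpers, not Raster Vision's code). If the learner looks for .png chips in a directory that only holds .npy chips, it finds zero items, and zero is smaller than any batch size, including 1.

```python
from pathlib import Path


def inferred_img_format(img_channels):
    """Hypothetical: mirrors the pairing seen in the two logged configs
    (3 channels -> 'png', 4 channels -> 'npy'). The real rule may differ."""
    return "png" if img_channels <= 3 else "npy"


def count_matching_chips(img_dir, img_format):
    """Count chip files in img_dir with the expected extension."""
    return len(list(Path(img_dir).glob(f"*.{img_format}")))


def check_batch_size(num_items, batch_sz):
    """Illustrative stand-in for the guard that raised the ConfigError."""
    if num_items < batch_sz:
        raise ValueError("Training dataset has fewer elements than batch size.")
```

With 392 .npy training chips on disk, count_matching_chips(img_dir, inferred_img_format(3)) returns 0, which is why lowering the batch size to 1 did not help.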

@AdeelH AdeelH added the bug label Feb 18, 2021
@Tobias1234
Author

Should I define something else in my code? I'm now getting:
TypeError: 'AxesSubplot' object is not subscriptable

Running train command...
2021-02-18 11:26:03:rastervision.pytorch_learner.learner: INFO - model=SemanticSegmentationModelConfig(backbone=<Backbone.resnet50: 'resnet50'>, pretrained=True, init_weights=None, external_def=None, type_hint='semantic_segmentation_model') solver=SolverConfig(lr=0.0001, num_epochs=10, test_num_epochs=2, test_batch_sz=4, overfit_num_steps=1, sync_interval=1, batch_sz=1, one_cycle=True, multi_stage=[], class_loss_weights=None, ignore_last_class=False, external_loss_def=None, type_hint='solver') data=SemanticSegmentationImageDataConfig(class_names=['building', 'background', 'null'], class_colors=['red', 'black', 'black'], img_sz=256, train_sz=None, num_workers=4, augmentors=['RandomRotate90', 'HorizontalFlip', 'VerticalFlip'], base_transform=None, aug_transform=None, plot_options=PlotOptions(transform=None, type_hint='plot_options'), preview_batch_limit=None, type_hint='semantic_segmentation_image_data', data_format=<SemanticSegmentationDataFormat.default: 'default'>, uri='/opt/data/output/chip', group_uris=None, group_train_sz=None, group_train_sz_rel=None, img_channels=4, channel_display_groups={'Input': (0, 1, 2, 3)}, img_format='npy', label_format='png') predict_mode=False test_mode=False overfit_mode=False eval_train=False save_model_bundle=True log_tensorboard=True run_tensorboard=False output_uri='/opt/data/output/train' type_hint='semantic_segmentation_learner'
2021-02-18 11:26:06:rastervision.pytorch_learner.learner: INFO - train_ds: 392 items
2021-02-18 11:26:06:rastervision.pytorch_learner.learner: INFO - valid_ds: 50 items
2021-02-18 11:26:06:rastervision.pytorch_learner.learner: INFO - test_ds: 50 items
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/opt/src/rastervision_pipeline/rastervision/pipeline/cli.py", line 248, in <module>
    main()
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 722, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 697, in main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 1066, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 895, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 535, in invoke
    return callback(*args, **kwargs)
  File "/opt/src/rastervision_pipeline/rastervision/pipeline/cli.py", line 240, in run_command
    runner=runner)
  File "/opt/src/rastervision_pipeline/rastervision/pipeline/cli.py", line 217, in _run_command
    command_fn()
  File "/opt/src/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline.py", line 115, in train
    backend.train(source_bundle_uri=self.config.source_bundle_uri)
  File "/opt/src/rastervision_pytorch_backend/rastervision/pytorch_backend/pytorch_learner_backend.py", line 75, in train
    learner.main()
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py", line 142, in main
    self.plot_dataloaders(self.preview_batch_limit)
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py", line 943, in plot_dataloaders
    batch_limit)
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py", line 936, in plot_dataloader
    self.plot_batch(x, y, output_path, batch_limit=batch_limit)
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner.py", line 199, in plot_batch
    self.plot_xyz(ax, x[i], y[i])
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner.py", line 223, in plot_xyz
    img_axes = ax[:len(channel_groups)]
TypeError: 'AxesSubplot' object is not subscriptable
/opt/data/output/Makefile:12: recipe for target '2' failed
make: *** [2] Error 1

@AdeelH
Collaborator

AdeelH commented Feb 18, 2021

You seem to have run up against an edge case in plotting samples from the dataset caused by batch size = 1. Good catch. This is another bug. Increasing the batch size should fix this particular error.
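The batch-size-1 edge case most likely comes from matplotlib's default squeeze=True in plt.subplots: with a single row, the returned axes array is 1-D, so taking one sample's row yields a lone Axes object, and slicing it raises the TypeError above. The sketch below shows the behaviour and one common remedy (squeeze=False); it is an illustration only, and the fix actually merged for this issue may look different. make_batch_axes is a hypothetical helper.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, e.g. inside a Docker container
import matplotlib.pyplot as plt


def make_batch_axes(batch_sz, ncols):
    """Return a figure and an axes array that is always 2-D: (batch_sz, ncols).

    With the default squeeze=True, plt.subplots(1, ncols) returns a 1-D
    array, so axes[i] is a single Axes and axes[i][:k] raises
    "'AxesSubplot' object is not subscriptable" (class name varies by version).
    """
    fig, axes = plt.subplots(nrows=batch_sz, ncols=ncols, squeeze=False)
    return fig, axes
```

With squeeze=False, axes[i] is always a row of Axes that can be sliced per channel group, regardless of batch size.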

I also notice that you are defining channel_display_groups but not passing it in. Passing it in will plot the elevation channel separately from the RGB image. You can do that like so:

data=SemanticSegmentationImageDataConfig(img_channels=len(channel_order), channel_display_groups=channel_display_groups),
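To make the effect of channel_display_groups concrete, here is a small NumPy sketch of how a mapping such as {'RGB': (0, 1, 2), 'elev': (3,)} selects channels from a 4-band H x W x C chip. split_display_groups is a hypothetical helper; Raster Vision's own plotting code is not reproduced here.

```python
import numpy as np


def split_display_groups(chip, channel_display_groups):
    """Split an H x W x C chip into one sub-image per display group.

    Each group keeps its channels in the listed order, so 'elev' comes out
    as an H x W x 1 array that can be plotted on its own axis.
    """
    return {
        name: chip[..., list(idxs)]
        for name, idxs in channel_display_groups.items()
    }


if __name__ == "__main__":
    chip = np.zeros((300, 300, 4), dtype=np.float32)
    groups = split_display_groups(chip, {"RGB": (0, 1, 2), "elev": (3,)})
    print({name: arr.shape for name, arr in groups.items()})
```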

@Tobias1234
Copy link
Author

Tobias1234 commented Feb 18, 2021

It seems to work now, but it crashes during training. I guess it's a CPU problem? I cleaned the output folder before running. CPU usage sometimes reaches 80% before it crashes. I haven't seen this warning at the start of a run before:
make: Warning: File '/opt/data/output/Makefile' has modification time 0.51 s in the future

I have a good graphics card, an NVIDIA GeForce RTX 3060 Ti, but only 16 GB of RAM (I will upgrade to 32 GB).

[Screenshot: CPU usage]

@AdeelH
Collaborator

AdeelH commented Feb 18, 2021

You need to enable GPU usage when running docker. Depending on your Docker version, you will need to pass in either --gpus=all or --runtime=nvidia to docker run.
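As a sketch of where the GPU flag goes, the helper below rebuilds the docker run command posted earlier in this thread as an argument list and inserts the flag before the image name. --gpus=all needs Docker 19.03 or newer with the NVIDIA Container Toolkit; --runtime=nvidia is for older nvidia-docker2 setups. build_docker_run_cmd is a hypothetical helper, and the mount paths are the ones from the original command.

```python
import shlex


def build_docker_run_cmd(gpu_flag="--gpus=all"):
    """Rebuild the docker run command from this thread, adding a GPU flag."""
    mounts = [
        ("C:/Users/tobbe/RV2/RV_CODE_DIR", "/opt/src/code"),
        ("C:/Users/tobbe/RV2/RV_OUT_DIR", "/opt/data/output"),
        ("C:/Users/tobbe/RV2/RV_DATA_INPUT_DIR", "/opt/data/data_input"),
    ]
    cmd = ["docker", "run", gpu_flag, "--ipc=host", "--rm", "-it",
           "--name", "devtest27"]
    for source, target in mounts:
        cmd += ["--mount", f"type=bind,source={source},target={target}"]
    # Options must come before the image name; the command to run comes after.
    cmd += ["quay.io/azavea/raster-vision:pytorch-latest", "/bin/bash"]
    return cmd


if __name__ == "__main__":
    print(shlex.join(build_docker_run_cmd()))  # shlex.join needs Python 3.8+
```

Once the flag is in place, the learner's startup log should report "CUDA available: True" instead of False.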

@greenhawktobias

greenhawktobias commented Feb 22, 2021

Getting the GPU to work with Docker is a bit cumbersome on Windows (I'm working on it), but I upgraded my RAM, so the pipeline runs now. However, I now get "IndexError: boolean index did not match indexed array along dimension 0; dimension is 256 but corresponding boolean dimension is 300":

Running predict command...
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO - Physical CPUs: 8
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO - Logical CPUs: 16
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO - Total memory:  15.59 GB
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO - Size of /opt/data volume:  250.98 GB
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO - Python version: 3.6.12 |Anaconda, Inc.| (default, Sep  8 2020, 23:10:56)
[GCC 7.3.0]
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO - nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89

/bin/sh: 1: nvidia-smi: not found
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO -
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO - Devices:
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO - PyTorch version: 1.7.1
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO - CUDA available: False
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO - CUDA version: 10.2
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO - CUDNN version: 7605
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO - Number of CUDA devices: 0
2021-02-22 10:13:29:rastervision.pytorch_learner.learner: INFO - Loading model weights from: /opt/data/tmp/tmplz45dsqa/model-bundle/model.pth
2021-02-22 10:13:30:rastervision.core.rv_pipeline.rv_pipeline: INFO - Making predictions for scene
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/opt/src/rastervision_pipeline/rastervision/pipeline/cli.py", line 248, in <module>
    main()
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 722, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 697, in main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 1066, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 895, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 535, in invoke
    return callback(*args, **kwargs)
  File "/opt/src/rastervision_pipeline/rastervision/pipeline/cli.py", line 240, in run_command
    runner=runner)
  File "/opt/src/rastervision_pipeline/rastervision/pipeline/cli.py", line 217, in _run_command
    command_fn()
  File "/opt/src/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline.py", line 165, in predict
    for s in dataset.validation_scenes
  File "/opt/src/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline.py", line 159, in _predict
    labels = self.predict_scene(scene, self.backend)
  File "/opt/src/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline.py", line 200, in predict_scene
    predict_batch(batch_chips, batch_windows)
  File "/opt/src/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline.py", line 187, in predict_batch
    batch_labels)
  File "/opt/src/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation.py", line 158, in post_process_batch
    labels.mask_fill(window, nodata_mask, fill_value=null_class_id)
  File "/opt/src/rastervision_core/rastervision/core/data/label/semantic_segmentation_labels.py", line 174, in mask_fill
    self.window_to_label_arr[window][mask] = fill_value
IndexError: boolean index did not match indexed array along dimension 0; dimension is 256 but corresponding boolean dimension is 300
/opt/data/output/Makefile:15: recipe for target '3' failed
make: *** [3] Error 1

@AdeelH
Copy link
Collaborator

AdeelH commented Feb 22, 2021

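You are setting img_sz but not passing it in to the config, so the data config falls back to its default of 256 (see img_sz=256 in the config dump above), while your chips are 300 x 300.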

Change

data=SemanticSegmentationImageDataConfig(img_channels=len(channel_order)),

to

data=SemanticSegmentationImageDataConfig(img_channels=len(channel_order), img_sz=img_sz),
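The IndexError itself is plain NumPy behaviour: a boolean mask must have the same shape as the array it indexes. The sketch below reproduces it under the assumption that the predicted label array comes out at img_sz (256) while the nodata mask is computed from the 300 x 300 chip; fill_nodata is a hypothetical stand-in, not Raster Vision's mask_fill.

```python
import numpy as np


def fill_nodata(label_arr, nodata_mask, fill_value):
    """Set label pixels to fill_value wherever nodata_mask is True.

    Boolean-mask indexing requires nodata_mask.shape == label_arr.shape,
    so a 256 x 256 label array cannot be indexed by a 300 x 300 mask.
    """
    label_arr[nodata_mask] = fill_value
    return label_arr


if __name__ == "__main__":
    try:
        fill_nodata(np.zeros((256, 256), dtype=np.uint8),
                    np.zeros((300, 300), dtype=bool), fill_value=2)
    except IndexError as err:
        print(err)  # same message as in the predict log
```

Passing img_sz=img_sz (300) keeps both shapes equal, so the mask indexing succeeds.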

@Tobias1234
Author

Works now. Thanks!

@Tobias1234
Author

Now, when running this script, I get UnpicklingError: unpickling stack underflow. Why is that?
Could it be something other than the code?

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /opt/data/torch-cache/hub/checkpoints/resnet50-19c8e357.pth
  5%|███▉                                                                           | 4.89M/97.8M [00:12<03:52, 419kB/s]
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/opt/src/rastervision_pipeline/rastervision/pipeline/cli.py", line 248, in <module>
    main()
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 722, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 697, in main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 1066, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 895, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.6/site-packages/click/core.py", line 535, in invoke
    return callback(*args, **kwargs)
  File "/opt/src/rastervision_pipeline/rastervision/pipeline/cli.py", line 240, in run_command
    runner=runner)
  File "/opt/src/rastervision_pipeline/rastervision/pipeline/cli.py", line 217, in _run_command
    command_fn()
  File "/opt/src/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline.py", line 115, in train
    backend.train(source_bundle_uri=self.config.source_bundle_uri)
  File "/opt/src/rastervision_pytorch_backend/rastervision/pytorch_backend/pytorch_learner_backend.py", line 74, in train
    learner = self.learner_cfg.build(self.tmp_dir, training=True)
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner_config.py", line 157, in build
    training=training)
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py", line 155, in __init__
    self.setup_model(model_def_path=model_def_path)
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py", line 273, in setup_model
    self.model = self.build_model()
  File "/opt/src/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner.py", line 41, in build_model
    pretrained_backbone=pretrained)
  File "/opt/conda/lib/python3.6/site-packages/torchvision/models/segmentation/segmentation.py", line 22, in _segm_resnet
    replace_stride_with_dilation=[False, True, True])
  File "/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py", line 265, in resnet50
    **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py", line 227, in _resnet
    progress=progress)
  File "/opt/conda/lib/python3.6/site-packages/torch/hub.py", line 559, in load_state_dict_from_url
    return torch.load(cached_file, map_location=map_location)
  File "/opt/conda/lib/python3.6/site-packages/torch/serialization.py", line 595, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/opt/conda/lib/python3.6/site-packages/torch/serialization.py", line 764, in _legacy_load
    magic_number = pickle_module.load(f, **pickle_load_args)
_pickle.UnpicklingError: unpickling stack underflow
/opt/data/output/Makefile:12: recipe for target '2' failed
make: *** [2] Error 1

@AdeelH
Collaborator

AdeelH commented Feb 25, 2021

Maybe the download didn't complete?
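If an incomplete download is the cause, one way to check is to compare the cached file against the hash prefix in its name. This sketch assumes the torch.hub naming convention, where a file such as resnet50-19c8e357.pth carries the first hex digits of its SHA-256 digest (this is what load_state_dict_from_url verifies when check_hash=True). checkpoint_is_intact and remove_if_corrupt are hypothetical helpers; deleting a bad file lets the next run download it again.

```python
import hashlib
import re
from pathlib import Path

# Assumed torch.hub naming: <name>-<sha256 prefix>.<ext>
HASH_PREFIX_RE = re.compile(r"-([a-f0-9]+)\.")


def checkpoint_is_intact(path):
    """Return True if the file's SHA-256 digest starts with the prefix in its name."""
    path = Path(path)
    match = HASH_PREFIX_RE.search(path.name)
    if match is None:
        raise ValueError(f"no hash prefix in file name: {path.name}")
    sha = hashlib.sha256()
    with path.open("rb") as f:
        for block in iter(lambda: f.read(1 << 20), b""):
            sha.update(block)
    return sha.hexdigest().startswith(match.group(1))


def remove_if_corrupt(path):
    """Delete a cached checkpoint that fails the hash check; return True if deleted."""
    path = Path(path)
    if path.exists() and not checkpoint_is_intact(path):
        path.unlink()
        return True
    return False


if __name__ == "__main__":
    print(remove_if_corrupt(
        "/opt/data/torch-cache/hub/checkpoints/resnet50-19c8e357.pth"))
```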
