50 changes: 50 additions & 0 deletions README.md
@@ -20,6 +20,7 @@ Fundamental research to develop new architectures for foundation models and A(G)

## News

- December, 2023: [LongNet](torchscale/model/LongNet.py) and [LongViT](examples/longvit/README.md) released
- October, 2023: RMSNorm and SwiGLU are now the default modules in RetNet
- November, 2022: TorchScale 0.1.1 released [[Paper](https://arxiv.org/abs/2211.13184)] [[PyPI](https://pypi.org/project/torchscale/)]

@@ -37,6 +38,18 @@ cd torchscale
pip install -e .
```

For faster training, install [Flash Attention](https://github.com/Dao-AILab/flash-attention) for Turing, Ampere, Ada, or Hopper GPUs:
```
pip install flash-attn
```
or [xFormers](https://github.com/facebookresearch/xformers) for Volta, Turing, Ampere, Ada, or Hopper GPUs:
```
# CUDA 11.8 version
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118
# CUDA 12.1 version
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121
```
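If you want to confirm which backend ended up in your environment, a quick import check like the sketch below is enough (optional; `flash_attn` and `xformers` are the import names of the packages installed above):
```python
# Optional sanity check: report which fast-attention backends are importable.
for name in ("flash_attn", "xformers"):
    try:
        __import__(name)
        print(f"{name}: available")
    except ImportError:
        print(f"{name}: not installed")
```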

## Getting Started

It takes only a few lines of code to create a model with the above fundamental research features enabled. Here is how to quickly obtain a BERT-like encoder:
@@ -85,6 +98,21 @@ It takes only several lines of code to create a RetNet model:
>>> print(retnet)
```

For LongNet models ([Flash Attention](https://github.com/Dao-AILab/flash-attention) required):
```python
>>> import torch
>>> from torchscale.architecture.config import EncoderConfig, DecoderConfig
>>> from torchscale.model.longnet import LongNetEncoder, LongNetDecoder

# Creating a LongNet encoder with the dilated pattern of segment_length=[2048,4096] and dilated_ratio=[1,2]
>>> config = EncoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True)
>>> longnet = LongNetEncoder(config)

# Creating a LongNet decoder with the dilated pattern of segment_length=[2048,4096] and dilated_ratio=[1,2]
>>> config = DecoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True)
>>> longnet = LongNetDecoder(config)
```
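Beyond printing the module, a rough smoke-test sketch is shown below. It assumes a CUDA device with flash-attn installed (required when `flash_attention=True`), that `LongNetDecoder` accepts an embedding module like the standard torchscale `Decoder`, and that its `forward` returns `(output, extra)`; adjust to the actual API if these assumptions do not hold:
```python
# Rough smoke-test sketch under the assumptions stated above; not an official example.
import torch
from torchscale.architecture.config import DecoderConfig
from torchscale.model.longnet import LongNetDecoder

config = DecoderConfig(vocab_size=64000, segment_length='[2048,4096]',
                       dilated_ratio='[1,2]', flash_attention=True)
embed_tokens = torch.nn.Embedding(config.vocab_size, config.decoder_embed_dim)

# flash-attn kernels need a CUDA device and half precision.
decoder = LongNetDecoder(config, embed_tokens=embed_tokens).cuda().half()

tokens = torch.randint(0, config.vocab_size, (1, 4096), device="cuda")
with torch.no_grad():
    out, _ = decoder(tokens)  # assumed (output, extra) return convention
print(out.shape)              # expected: torch.Size([1, 4096, 64000])
```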

## Key Features

- [DeepNorm to improve the training stability of Post-LayerNorm Transformers](https://arxiv.org/abs/2203.00555)
@@ -142,6 +170,8 @@ We have examples of how to use TorchScale in the following scenarios/tasks:

- Vision

* [LongViT](examples/longvit/README.md)

* ViT/BEiT [In progress]

- Speech
@@ -228,6 +258,26 @@ If you find this repository useful, please consider citing our work:
}
```

```
@article{longnet,
  author  = {Jiayu Ding and Shuming Ma and Li Dong and Xingxing Zhang and Shaohan Huang and Wenhui Wang and Nanning Zheng and Furu Wei},
  title   = {{LongNet}: Scaling Transformers to 1,000,000,000 Tokens},
  journal = {ArXiv},
  volume  = {abs/2307.02486},
  year    = {2023}
}
```

```
@article{longvit,
  author  = {Wenhui Wang and Shuming Ma and Hanwen Xu and Naoto Usuyama and Jiayu Ding and Hoifung Poon and Furu Wei},
  title   = {When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology},
  journal = {ArXiv},
  volume  = {abs/2312.03558},
  year    = {2023}
}
```

## Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
39 changes: 39 additions & 0 deletions examples/fairseq/README.md
@@ -251,6 +251,45 @@ python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
--use-xmoe
```

### LongNet Model

```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
${PATH_TO_DATA} \
--num-workers 2 \
--activation-fn gelu \
--share-decoder-input-output-embed \
--validate-interval-updates 1000 \
--save-interval-updates 1000 \
--no-epoch-checkpoints \
--memory-efficient-fp16 \
--fp16-init-scale 4 \
--arch lm_base \
--task language_modeling \
--sample-break-mode none \
--tokens-per-sample 4096 \
--optimizer adam --adam-betas "(0.9, 0.98)" \
--adam-eps 1e-08 \
--clip-norm 0.0 \
--lr 5e-4 \
--lr-scheduler polynomial_decay \
--warmup-updates 750 \
--dropout 0.1 \
--attention-dropout 0.1 \
--weight-decay 0.01 \
--batch-size 4 \
--update-freq 1 \
--required-batch-size-multiple 1 \
--total-num-update 50000 \
--max-update 50000 \
--seed 1 \
--ddp-backend=c10d \
--flash-attention \
--segment-length [2048,4096] \
--dilated-ratio [1,2]
```
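`--segment-length` and `--dilated-ratio` are string-encoded lists that are paired element-wise to define the dilated attention pattern. The sketch below only illustrates how such a pair of flags is commonly interpreted; the actual parsing inside TorchScale may differ:
```python
# Illustration only: how the string-valued --segment-length / --dilated-ratio
# flags pair up into dilated-attention branches; TorchScale's own parsing may differ.
import ast

segment_length = ast.literal_eval("[2048,4096]")  # -> [2048, 4096]
dilated_ratio = ast.literal_eval("[1,2]")         # -> [1, 2]

# Each (segment, ratio) pair is one branch: the 4096-token sample is split into
# windows of `segment` tokens, and attention inside a window keeps every
# `ratio`-th token (ratio=1 keeps them all).
for segment, ratio in zip(segment_length, dilated_ratio):
    print(f"window={segment} tokens, dilation={ratio}")
```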

## Example: Machine Translation

### Data Format
41 changes: 40 additions & 1 deletion examples/fairseq/models/language_modeling.py
@@ -25,6 +25,7 @@

from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder
from torchscale.model.LongNet import LongNetDecoder

DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)
@@ -196,6 +197,19 @@ class LanguageConfig(FairseqDataclass):
    xpos_scale_base: Optional[int] = field(
        default=512,
    )
    flash_attention: Optional[bool] = field(
        default=False,
    )
    seq_parallel: Optional[bool] = field(
        default=False,
    )
    segment_length: Optional[str] = field(
        default='',
    )
    dilated_ratio: Optional[str] = field(
        default='',
    )



@register_model("lm", dataclass=LanguageConfig)
@@ -256,7 +270,13 @@ def build_model(cls, args, task):
        config = DecoderConfig()
        config.override(args)

        decoder = LMDecoder(
        if args.segment_length != '':
            assert args.dilated_ratio != ''
            DECODER_CLASS = LongNetLMDecoder
        else:
            DECODER_CLASS = LMDecoder

        decoder = DECODER_CLASS(
            config,
            embed_tokens,
            embed_positions,
@@ -291,6 +311,25 @@ def reorder_incremental_state_scripting(
                incremental_state[module][key] = result


class LongNetLMDecoder(LongNetDecoder, FairseqIncrementalDecoder):
    def forward(self, src_tokens, **kwargs):
        self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
        return super().forward(src_tokens, self_attn_padding_mask, **kwargs)

    def max_positions(self):
        return self.embed_positions.max_positions

    def reorder_incremental_state_scripting(
        self,
        incremental_state,
        new_order,
    ):
        for module in incremental_state:
            for key in incremental_state[module]:
                result = incremental_state[module][key].index_select(0, new_order)
                incremental_state[module][key] = result


@register_model_architecture("lm", "lm_base")
def base_lm_architecture(args):
    # backward compatibility for older model checkpoints
71 changes: 71 additions & 0 deletions examples/longvit/README.md
@@ -0,0 +1,71 @@
# [(LongViT) When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology](https://arxiv.org/abs/2312.03558)

**LongViT** is a vision Transformer that can process gigapixel images (e.g., 32,768x32,768 images) in an end-to-end manner. We split the image into millions of patches and employ [LongNet](https://arxiv.org/abs/2307.02486) to directly model the extremely long sequence. We apply LongViT in the field of computational pathology and achieve remarkable performance on cancer subtyping and survival prediction tasks.
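For a sense of scale, the back-of-the-envelope sketch below (assuming the 32x32 patch size of the released model) shows how image side length translates into the sequence length LongNet has to model:
```python
# Back-of-the-envelope patch counting, assuming 32x32 patches as in the released model.
def num_patches(image_size, patch_size=32):
    return (image_size // patch_size) ** 2

for size in (1024, 16384, 32768):
    print(f"{size}x{size} image -> {num_patches(size):,} patch tokens")
# 32768x32768 -> 1,048,576 patch tokens (~1M, the sequence length used for fine-tuning)
```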


## Setup
```
pip install -r requirements.txt
pip install git+https://github.com/shumingma/fairseq.git@moe
pip install -v -U git+https://github.com/facebookresearch/xformers.git@v0.0.20#egg=xformers
```
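To verify that the source-built xFormers wheel is importable (optional):
```python
# Optional check that the pinned xformers build installed correctly.
import xformers
print(xformers.__version__)  # the v0.0.20 tag above should report 0.0.20
```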


## Pretraining

We perform self-supervised pretraining on TCGA diagnostic slides using the [DINO](https://arxiv.org/abs/2104.14294) objective. The detailed instructions can be found at [`get_started_for_tcga_pretraining.md`](get_started/get_started_for_tcga_pretraining.md).

The link to the pretrained LongViT model on TCGA diagnostic slides:
- [`LongViT`](https://conversationhub.blob.core.windows.net/beit-share-public/longvit/longvit_small_patch32_1024.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D): #layer=12; hidden=384; FFN factor=4x; #head=16; patch=32x32
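To take a quick look at the downloaded weights before wiring them into the fine-tuning configs, a small inspection sketch such as the following can be used (it only assumes the file is a regular `torch.save` checkpoint; the key layout is not specified here):
```python
# Minimal checkpoint inspection sketch; no assumption about the exact key layout.
import torch

# On recent PyTorch you may need weights_only=False for non-tensor objects in the file.
ckpt = torch.load("longvit_small_patch32_1024.pth", map_location="cpu")
if isinstance(ckpt, dict):
    print("top-level keys:", list(ckpt.keys())[:10])
else:
    print("checkpoint object:", type(ckpt))
```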


## Fine-tuning on Subtyping Classification

We fine-tune for cancer subtyping on images with sizes up to 32,768x32,768 (1M patches). The detailed instructions can be found at [`get_started_for_tcga_subtyping.md`](get_started/get_started_for_tcga_subtyping.md).


## Fine-tuning on Survival Prediction

We fine-tune for survival prediction on images with sizes up to 32,768x32,768 (1M patches). The detailed instructions can be found at [`get_started_for_tcga_survival_prediction.md`](get_started/get_started_for_tcga_survival_prediction.md).


## Citation

If you find this repository useful, please consider citing our work:
```
@article{longvit,
title={When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology},
author={Wang, Wenhui and Ma, Shuming and Xu, Hanwen and Usuyama, Naoto and Ding, Jiayu and Poon, Hoifung and Wei, Furu},
journal={arXiv preprint arXiv:2312.03558},
year={2023}
}

@article{longnet,
title={LongNet: Scaling transformers to 1,000,000,000 tokens},
author={Ding, Jiayu and Ma, Shuming and Dong, Li and Zhang, Xingxing and Huang, Shaohan and Wang, Wenhui and Zheng, Nanning and Wei, Furu},
journal={arXiv preprint arXiv:2307.02486},
year={2023}
}

@article{torchscale,
title={TorchScale: Transformers at scale},
author={Ma, Shuming and Wang, Hongyu and Huang, Shaohan and Wang, Wenhui and Chi, Zewen and Dong, Li and Benhaim, Alon and Patra, Barun and Chaudhary, Vishrav and Song, Xia and others},
journal={arXiv preprint arXiv:2211.13184},
year={2022}
}
```


## Acknowledgement

This repository is built using the [BEiT-3](https://github.com/microsoft/unilm/tree/master/beit3), [MCAT](https://github.com/mahmoodlab/MCAT), [DINO](https://github.com/facebookresearch/dino), and [HIPT](https://github.com/mahmoodlab/HIPT) repositories and the [timm](https://github.com/rwightman/pytorch-image-models) library.


## License
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.

[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)

### Contact Information

For help or issues using LongViT models, please submit a GitHub issue.
78 changes: 78 additions & 0 deletions examples/longvit/data_preprocessing/cache_transformed_images.py
@@ -0,0 +1,78 @@
import os
import sys
import torch
import random
import argparse
from PIL import Image, ImageFilter, ImageOps
from multiprocessing import Pool, cpu_count
from timm.data.transforms import RandomResizedCropAndInterpolation
import torchvision.transforms as transforms

Image.MAX_IMAGE_PIXELS = 6400000000


def build_transform(input_size):
    train_interpolation = "bicubic"
    t = [
        RandomResizedCropAndInterpolation(input_size, scale=(0.5, 1.0), interpolation=train_interpolation),
        transforms.RandomHorizontalFlip(),
    ]
    t = transforms.Compose(t)

    return t


def pil_loader(path):
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")


def save_image(transformed_img, output_image_path):
    if isinstance(transformed_img, torch.Tensor):
        transformed_img = transforms.ToPILImage()(transformed_img)
    transformed_img.save(output_image_path)


def get_image_files(input_dir):
    for root, _, files in os.walk(input_dir):
        for file in files:
            if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
                yield os.path.join(root, file)


def transform_and_save_crops(args):
    input_path, input_dir, output_dir, transform = args
    print(input_path)
    file_basename = os.path.basename(input_path)

    img = pil_loader(input_path)
    transformed_img = transform(img)
    output_image_path = os.path.join(output_dir, file_basename)
    save_image(transformed_img, output_image_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Save transformed images in a directory.')
    parser.add_argument('input_dir', help='Path to the input directory.')
    parser.add_argument('output_dir', help='Path to the output directory.')
    parser.add_argument('-p', '--processes', type=int, default=cpu_count(), help='Number of processes to use. Default: number of CPU cores')
    parser.add_argument('--input_size', type=int, default=16384, help='input image size')
    args = parser.parse_args()

    input_dir = args.input_dir
    output_dir = args.output_dir
    num_processes = args.processes
    input_size = args.input_size
    print("num_processes: {}".format(num_processes))
    print("input_size: {}".format(input_size))

    transform = build_transform(input_size=input_size)

    image_files = list(get_image_files(input_dir))
    task_args = [(file, input_dir, output_dir, transform) for file in image_files]

    os.makedirs(output_dir, exist_ok=True)

    with Pool(processes=num_processes) as pool:
        pool.map(transform_and_save_crops, task_args)
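For reference, a single-image version of what each pool worker above does might look like the sketch below; the paths are placeholders and the helpers are the ones defined in this script:
```python
# Single-image sketch using the helpers from cache_transformed_images.py above.
# Paths are placeholders; the real script walks input_dir and uses a process pool.
from cache_transformed_images import build_transform, pil_loader, save_image

transform = build_transform(input_size=16384)
img = pil_loader("wsi_images/TCGA-example.jpg")               # hypothetical input image
save_image(transform(img), "cached_images/TCGA-example.jpg")  # hypothetical output path
```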
45 changes: 45 additions & 0 deletions examples/longvit/data_preprocessing/convert_wsi_to_images.py
@@ -0,0 +1,45 @@
import os
import glob
import argparse
import openslide

from PIL import Image
from concurrent.futures import ProcessPoolExecutor


def convert_wsi_to_images(slide_path, image_path, target_size, level=0):
    slide = openslide.open_slide(slide_path)
    level_dims = slide.level_dimensions
    region = slide.read_region((0,0), level, level_dims[level])
    region = region.convert("RGB")
    print("convert: {}({}) -> {}".format(slide_path, region.size, image_path))
    resized_img = region.resize((target_size, target_size), Image.BICUBIC)
    resized_img.save(image_path)


def process_slides(input_folder, output_folder, target_size, level=0):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    slide_paths = glob.glob(os.path.join(input_folder, "*.svs"))

    with ProcessPoolExecutor(max_workers=1) as executor:
        for slide_path in slide_paths:
            image_path = os.path.join(output_folder, os.path.basename(slide_path).split(".svs")[0] + ".jpg")
            executor.submit(convert_wsi_to_images, slide_path, image_path, target_size, level=level)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert slides into images")
    parser.add_argument("input_folder", type=str, help="Folder containing the input .svs whole-slide images")
    parser.add_argument("output_folder", type=str, help="Folder where the converted .jpg images are written")
    parser.add_argument("target_size", type=int, help="Side length (in pixels) of the resized square output image")
    parser.add_argument("level", type=int, help="WSI pyramid level to read (0 is the highest resolution)")

    args = parser.parse_args()
    input_folder = args.input_folder
    output_folder = args.output_folder
    target_size = args.target_size
    level = args.level

    process_slides(input_folder, output_folder, target_size, level=level)
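Similarly, a single-slide sketch for the converter above (placeholder paths; `openslide-python` and the native OpenSlide library are required):
```python
# Single-slide sketch calling convert_wsi_to_images defined above; paths are placeholders.
from convert_wsi_to_images import convert_wsi_to_images

convert_wsi_to_images(
    slide_path="slides/TCGA-example.svs",  # hypothetical .svs whole-slide image
    image_path="images/TCGA-example.jpg",  # hypothetical output path
    target_size=16384,                     # resize to 16,384 x 16,384 pixels
    level=0,                               # read the highest-resolution pyramid level
)
```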