Skip to content

Latest commit

 

History

History
36 lines (30 loc) · 1.58 KB

00769.md

File metadata and controls

36 lines (30 loc) · 1.58 KB
title tags categories
DJLのPyTorchバックエンドでMPS (Metal Performance Shaders) を使うメモ
Java
DJL
Machine Learning
PyTorch
MPS
Dev
Java
ai
djl

DJL (Deep Java Library) 0.20.0以降で MPS が使えるようになっていました。 サンプルが見当たらなかったので試したメモ。

サンプルコードは こちら です。Apple M2 Pro、メモリ32 GB、macOS 13.5.2で試しました。

次のようにDeviceインスタンスをDevice.of("mps", 0)で作れば良いようです。

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public class Main {
	public static void main(String[] args) {
		int dimension = 1024;
		Device device = Device.of("mps", 0);
		//Device device = Device.cpu();
		System.out.println(device.isGpu()); // false
		try (NDManager manager = NDManager.newBaseManager(device)) {
			NDArray array1 = manager.randomUniform(0, 1, new Shape(dimension, dimension));
			NDArray array2 = manager.randomUniform(0, 1, new Shape(dimension, dimension));
			NDArray result = array1.add(array2).mul(10).matMul(array1.transpose()).div(5);
			System.out.println(result);
		}
	}
}

MPS自体はGPUではないですが、MPSのAPIを使うことで、GPUが利用されるため、MPSを使ったコードを実行するとアクティビティモニタで % GPU の数字が0より大きくなります。