
Training Transformer on TPU #525

Closed
DevKretov opened this issue Apr 28, 2020 · 1 comment
DevKretov commented Apr 28, 2020

Description

Hello, I was wondering how large the batch size can be for TPU training. I'm currently training a vanilla Transformer model in Colab and can barely fit it into TPU memory. My batch size is 128, sequences are padded with the padded_batch function, and max_len is 512. It seems to me that I'm missing something, because it's suspicious that the TPU cannot handle batches an order of magnitude larger (e.g. 2048).
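For reference, this is roughly the kind of bucketed batching I have in mind as an alternative to a fixed padded_batch with everything padded to 512 tokens. It is only a minimal tf.data sketch: "dataset" stands in for the actual (inputs, targets) token dataset, and the boundaries and batch sizes are placeholder values; the idea is that shorter sequences get larger batches so the per-batch token count stays roughly constant.

import tensorflow as tf

# Hypothetical (inputs, targets) dataset of 1-D int token tensors.
def example_length(inputs, targets):
    return tf.shape(inputs)[0]

boundaries = [64, 128, 256, 512]        # sequence-length bucket edges (placeholders)
batch_sizes = [256, 128, 64, 32, 16]    # one batch size per bucket, larger for short sequences

bucketed = dataset.apply(
    tf.data.experimental.bucket_by_sequence_length(
        element_length_func=example_length,
        bucket_boundaries=boundaries,
        bucket_batch_sizes=batch_sizes,
        padded_shapes=([None], [None])))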

One thing I did try is running the TPU profiler, but I could not get it to work, since the model doesn't output anything for it to track.
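For reference, the profiler setup I was attempting looks roughly like this (a sketch only, assuming the jax.profiler server API; the port number is arbitrary):

import jax.profiler

# Start a profiler server on the host that drives the TPU; TensorBoard's profile
# plugin (or the capture_tpu_profile tool) can then connect to this port while a
# training step is running and record a trace.
server = jax.profiler.start_server(9999)

# ... start the training loop here and capture the trace during a step ...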

That's why my question is: what are the best practices for training a Trax Transformer on TPUs?

Error log

RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 12.64G of 8.00G hbm. Exceeded hbm capacity by 4.64G.

Total hbm usage >= 12.64G:
    reserved        529.00M 
    program          12.13G 
    arguments       unknown size 

Output size unknown.

Program hbm requirement 12.13G:
    reserved           4.0K
    global           196.0K
    HLO temp         12.13G (58.5% utilization: Unpadded (7.09G) Padded (12.12G), 0.1% fragmentation (10.34M))

  Largest program allocations in hbm:

  1. Size: 937.50M
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((2,), (0,)), ((), ()))\n                                 precision=None ]"
     Shape: f32[64,128,30000]{1,2,0:T(8,128)}
     Unpadded size: 937.50M
     XLA label: %fusion.1546 = (f32[64,128]{1,0:T(8,128)}, f32[64,128]{1,0:T(8,128)}, f32[64,128,30000]{1,2,0:T(8,128)}) fusion(f32[64,128]{1,0:T(8,128)} %fusion.9002.remat3, f32[64,128]{1,0:T(8,128)} %fusion.28213.remat, f32[30000]{0:T(1024)} %get-tuple-element.4759, f32...
     Allocation type: HLO temp
     ==========================

  2. Size: 512.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
     Unpadded size: 128.00M
     Extra memory due to padding: 384.00M (4.0x expansion)
     XLA label: %copy.249 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4769), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  3. Size: 512.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
     Unpadded size: 128.00M
     Extra memory due to padding: 384.00M (4.0x expansion)
     XLA label: %copy.248 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4765), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  4. Size: 512.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
     Unpadded size: 128.00M
     Extra memory due to padding: 384.00M (4.0x expansion)
     XLA label: %copy.247 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4761), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  5. Size: 512.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
     Unpadded size: 128.00M
     Extra memory due to padding: 384.00M (4.0x expansion)
     XLA label: %copy.246 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4757), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  6. Size: 512.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
     Unpadded size: 128.00M
     Extra memory due to padding: 384.00M (4.0x expansion)
     XLA label: %copy.245 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4753), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  7. Size: 512.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
     Unpadded size: 128.00M
     Extra memory due to padding: 384.00M (4.0x expansion)
     XLA label: %copy.244 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4747), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  8. Size: 512.00M
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n                                 precision=None ]"
     Shape: f32[64,8,512,512]{2,3,1,0:T(8,128)}
     Unpadded size: 512.00M
     XLA label: %fusion.2186 = (f32[64,8,512]{2,1,0:T(8,128)}, f32[64,8,512]{2,1,0:T(8,128)}, f32[64,8,512,512]{2,3,1,0:T(8,128)}) fusion(f32[64,8,512]{2,1,0:T(8,128)} %fusion.2753, pred[64,512]{1,0:T(8,128)E(32)} %get-tuple-element.4382, f32[64,8,512]{2,1,0:T(8,128)} %fu...
     Allocation type: HLO temp
     ==========================

  9. Size: 512.00M
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n                                 precision=None ]"
     Shape: f32[64,8,512,512]{2,3,1,0:T(8,128)}
     Unpadded size: 512.00M
     XLA label: %convolution-base-dilated.117.remat5 = f32[64,8,512,512]{2,3,1,0:T(8,128)} convolution(bf16[64,512,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.312, bf16[64,512,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.314), window={size=64x8 stride=63x7 lhs_dilate=64x8}, dim_labels...
     Allocation type: HLO temp
     ==========================

  10. Size: 256.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
     Unpadded size: 64.00M
     Extra memory due to padding: 192.00M (4.0x expansion)
     XLA label: %reshape.4751 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3020), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  11. Size: 256.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
     Unpadded size: 64.00M
     Extra memory due to padding: 192.00M (4.0x expansion)
     XLA label: %reshape.4755 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3021), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  12. Size: 256.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
     Unpadded size: 64.00M
     Extra memory due to padding: 192.00M (4.0x expansion)
     XLA label: %reshape.4759 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3022), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  13. Size: 256.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
     Unpadded size: 64.00M
     Extra memory due to padding: 192.00M (4.0x expansion)
     XLA label: %reshape.4763 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3023), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  14. Size: 256.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
     Unpadded size: 64.00M
     Extra memory due to padding: 192.00M (4.0x expansion)
     XLA label: %reshape.4767 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3024), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  15. Size: 256.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
     Unpadded size: 64.00M
     Extra memory due to padding: 192.00M (4.0x expansion)
     XLA label: %reshape.4771 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3025), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  16. Size: 128.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)}
     Unpadded size: 32.00M
     Extra memory due to padding: 96.00M (4.0x expansion)
     XLA label: %reshape.4740 = pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)} reshape(pred[33554432]{0:T(1024)E(32)} %fusion.2426), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  17. Size: 128.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)}
     Unpadded size: 32.00M
     Extra memory due to padding: 96.00M (4.0x expansion)
     XLA label: %reshape.4745 = pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)} reshape(pred[33554432]{0:T(1024)E(32)} %fusion.2431), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  18. Size: 128.00M
     Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
     Shape: pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)}
     Unpadded size: 32.00M
     Extra memory due to padding: 96.00M (4.0x expansion)
     XLA label: %reshape.4791 = pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)} reshape(pred[33554432]{0:T(1024)E(32)} %fusion.2427), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
     Allocation type: HLO temp
     ==========================

  19. Size: 128.00M
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n                                 precision=None ]"
     Shape: f32[64,8,128,512]{3,2,1,0:T(8,128)}
     Unpadded size: 128.00M
     XLA label: %fusion.4304 = (f32[64,8,128]{2,1,0:T(8,128)}, f32[64,8,128,512]{3,2,1,0:T(8,128)}) fusion(pred[64,512]{1,0:T(8,128)E(32)} %get-tuple-element.4384, bf16[64,128,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.85, bf16[64,512,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.389)...
     Allocation type: HLO temp
     ==========================

  20. Size: 128.00M
     Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n                                 precision=None ]"
     Shape: f32[64,8,128,512]{3,2,1,0:T(8,128)}
     Unpadded size: 128.00M
     XLA label: %fusion.4305 = (f32[64,8,128]{2,1,0:T(8,128)}, f32[64,8,128,512]{3,2,1,0:T(8,128)}) fusion(pred[64,512]{1,0:T(8,128)E(32)} %get-tuple-element.4384, bf16[64,128,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.464, bf16[64,512,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.466...
     Allocation type: HLO temp
     ==========================
lukaszkaiser (Contributor) commented
We have improved memory use since that release, so it would be a good idea to try again. But please remember that a TPU core on Colab has only 8 GB of memory, so a large batch may have trouble fitting with Transformer.
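As a rough illustration, a back-of-the-envelope estimate using the shapes reported in the log above (per-core batch 64, 8 heads, length 512) already accounts for a large share of those 8 GB; this is only a sketch, not an exact accounting of the compiled program:

# Per-core sizes of the largest attention buffers from the log:
batch, heads, seq = 64, 8, 512

attn_logits_bytes = batch * heads * seq * seq * 4   # f32[64,8,512,512] attention scores: 512 MiB
dropout_mask_bytes = batch * heads * seq * seq * 4  # pred mask padded ~4x on TPU (per log): 512 MiB

per_attention_layer = attn_logits_bytes + dropout_mask_bytes
print(per_attention_layer / 2**30, "GiB per attention layer")  # = 1.0 GiB

# With several such buffers live at once across layers (plus the ~937 MB logits
# tensor over the 30,000-token vocabulary), 8 GB of HBM is exhausted quickly;
# halving the batch or the sequence length shrinks these buffers by 2x and 4x.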
