Skip to content

Commit

Permalink
Update the test mps.
Browse files Browse the repository at this point in the history
Remove torch._six from test_mps (#326)

Fix test_zero_grad() (#330)

Fix bilinear backward pass (#331)

* Fix bilinear backward pass

* Remove comment

Update macOS 12 blocklist (#323)

* Update macOS 12 blocklist
- move sum, masked.var, mul to low precision list
- unblock them from running

* - mark __rdiv__ failures as accumulate error exceeds atol/rtol

Fix nn.functional.embedding grad (#335)

- casting the input tensor to float32 and cast back the output tensor
- unblock the test

Fix prelu backward (#334)

Reduction cast f16 to f32 only on macOS 12 (#332)

- unblock rdiv float16

Fix trace op (#340)

- give warnings of converting int64 for reduction ops
- use cast tensor for reduction sum on trace
- unblock trace from running

Update random result list (#339)

* - move nn.functional.feature_alpha_dropoutwith_train, normalnumber_mean, new_empty_strided to expected failures

* - update new_empty_strided

---------

Co-authored-by: Kulin Seth <kulin_seth@apple.com>

Enable int8 in TestConsistency (#347)

Dev/skotapati/copy broadcasting (#350)

* Handle broadcasting by expanding src tensor in Copy.mm

* Unblock linalg_matrix_power

* Improved formatting

Add the functionality to dump MPS ops.

1. DUMP_MPS_OPS to use LoggingTensor to dump out the ATen ops.
2. Skip running the EXPECTTEST list, as some tests are still
   seg-faulting

Fix lintrunner errors (#353)

* Fix lintrunner errors

* - move normal_in_place to random result list

Fixed the test_mps.

Test mps is updated.
  • Loading branch information
kulinseth authored and skotapati committed Apr 6, 2023
1 parent b0829ba commit 14935c0
Show file tree
Hide file tree
Showing 3 changed files with 1,519 additions and 577 deletions.
102 changes: 102 additions & 0 deletions test/cuda_results.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
ConsistencyTest: {
nn.functional.conv_transpose2d:
[[[7.399066925048828, 4.4053635597229, -25.85348129272461,
58.88909149169922, -88.75193786621094, -18.98126983642578, 9.437820434570312],
[-59.78305435180664, -65.34088134765625, -108.04747009277344, 196.6062469482422,
71.39350891113281, 37.8786735534668, -69.55322265625], [92.78504943847656,
91.24403381347656, -94.33301544189453, 9.261059761047363, -182.10206604003906,
141.4270477294922, 146.89010620117188], [-14.363212585449219, 43.454036712646484,
-76.1098403930664, 242.9479522705078, 198.1458282470703, -49.77315139770508,
5.891449451446533], [-43.56822967529297, 4.782844066619873, -29.526945114135742,
65.15388488769531, 161.29757690429688, 118.60847473144531, 27.08570671081543],
[68.29853057861328, -11.507468223571777, 2.044086217880249, 11.003862380981445,
34.993282318115234, -21.256723403930664, 91.49512481689453], [-70.4466781616211,
69.04386138916016, 7.764842987060547, 7.61972713470459, -28.99899673461914,
54.575748443603516, -5.762258052825928]], [[-36.238487243652344, 37.29551696777344,
-22.012331008911133, -30.1353702545166, 33.82851028442383, 33.00322341918945,
2.7218000888824463], [-7.999058246612549, 122.72489929199219, -1.0639530420303345,
2.9564287662506104, -143.1276092529297, -110.75650024414062, 48.0764274597168],
[-91.0599136352539, -11.656601905822754, 69.62447357177734, 88.12522888183594,
337.3008728027344, -76.9416732788086, -110.24406433105469], [-108.1512451171875,
98.42401123046875, 142.46144104003906, -127.48089599609375, -3.367496967315674,
86.82833099365234, 86.29623413085938], [-14.339198112487793, -52.287410736083984,
171.43614196777344, 200.14817810058594, 200.35476684570312, -189.4150390625,
-46.86980056762695], [30.196495056152344, 25.22877311706543, 95.29426574707031,
4.455311298370361, 118.48747253417969, 87.11080932617188, -83.6124038696289],
[-2.5434072017669678, 91.8791732788086, -10.615175247192383, -12.58531379699707,
-49.3439826965332, 33.37324523925781, -5.983145713806152]], [[4.551003932952881,
15.84842586517334, -46.354671478271484, 14.721636772155762, 39.01048278808594,
49.70054244995117, -18.268564224243164], [16.728954315185547, 129.43505859375,
-4.6139116287231445, -3.382319688796997, -238.76353454589844, 13.42194938659668,
40.393280029296875], [-2.335604429244995, -85.94283294677734, -142.2253875732422,
135.27537536621094, 18.01512336730957, -26.331714630126953, -33.35443878173828],
[-79.17593383789062, -93.72674560546875, -110.94194030761719, -61.455223083496094,
6.811624526977539, 129.06478881835938, 12.435402870178223], [10.859378814697266,
41.3059196472168, 143.55824279785156, -41.754737854003906, -235.32406616210938,
-70.98460388183594, 130.46929931640625], [193.57574462890625, -142.5060272216797,
-102.45012664794922, 124.68048095703125, 136.05215454101562, -9.650590896606445,
-45.59521484375], [-37.829593658447266, 39.12519454956055, 9.293094635009766,
-18.8004093170166, -0.7294210195541382, 51.884910583496094, 36.15913391113281]],
[[-15.651233673095703, 16.31340980529785, -26.752052307128906, 6.281721115112305,
43.765541076660156, -13.097319602966309, -30.443206787109375], [10.67841911315918,
66.1829605102539, -9.394262313842773, -131.45101928710938, -38.621002197265625,
65.9507064819336, 48.76960372924805], [-76.0918197631836, -9.108996391296387,
13.64936637878418, 96.7411880493164, 124.2474365234375, -111.50318145751953,
-42.397071838378906], [-83.31562805175781, 32.27967071533203, 250.08163452148438,
58.24131393432617, 129.95318603515625, -10.683560371398926, -123.84668731689453],
[-11.536887168884277, -15.220125198364258, 197.18821716308594, -31.680112838745117,
-81.35874938964844, 157.96974182128906, 105.61251831054688], [78.15926361083984,
-84.49744415283203, -73.91180419921875, 86.370361328125, 77.87918090820312,
55.3555908203125, -7.273794651031494], [25.232547760009766, 30.352109909057617,
53.722267150878906, 44.87421798706055, 44.618812561035156, 4.511796951293945,
9.039834976196289]]]
}
UnitTest: {
norm:
[
{
dtype: f16,
args: [[[ 8.9453, 4.0859, 0.1230, 2.1367, -5.0000],
[ 7.2773, -4.6953, -3.5586, 8.2812, -0.8789],
[ 0.7119, -1.4854, 6.8633, -7.9805, -3.6562],
[-1.0195, -7.2695, -0.0264, -3.5078, -0.2900],
[ 8.7656, 5.8984, -2.3125, -0.0352, 5.2812]],],
params: [0.5,],
res: [2000.]
},
{
dtype: f16,
args: [[[[ 8.9219, 3.0508, -3.0234, -5.6250, -5.3516],
[-5.8906, 5.2109, -7.2500, 7.3047, -0.1846],
[-2.1367, -8.8047, -3.4727, -3.0859, 4.9062],
[ 2.1797, -8.5078, 6.1445, -5.0547, 2.8828],
[-2.6191, 4.6680, -4.1758, 8.7734, -5.4844]],

[[-5.8984, 7.3281, -7.3672, -0.0879, 7.0039],
[ 2.0117, -6.4258, 8.6250, 2.5137, -2.2676],
[-7.2578, 1.6875, 7.8750, 7.5078, 0.8350],
[-4.8164, -3.6914, -3.9199, 4.9219, -4.6680],
[ 5.0547, -7.1289, 2.3633, 3.7793, -7.4375]],

[[-8.6953, -3.8750, 0.8965, -4.4453, 6.1328],
[ 8.6719, 2.5586, -3.0664, -7.7891, 2.5234],
[ 5.8008, 0.5977, 4.9219, 3.0156, 3.6211],
[-6.0898, -3.4883, 2.6543, 7.1992, 5.9414],
[-3.6035, 8.3906, 2.2070, -1.1162, 7.2852]],

[[-2.4531, -2.9180, 6.2422, -6.3711, -8.3516],
[ 3.3398, -8.5078, -8.9375, -2.0312, -4.3320],
[-1.4326, -4.5000, -0.3252, -6.8555, -8.2969],
[ 5.8438, 5.6094, -6.6797, -0.0439, 3.6035],
[ 4.5859, 7.1016, -0.8086, 5.6953, 0.5098]],

[[ 3.0859, 4.4844, 0.6152, 7.9609, -7.6562],
[-0.7998, -3.4023, 5.7734, -2.4785, 5.9219],
[ 7.1094, 1.4502, -7.1289, 4.7188, -4.8359],
[ 2.7422, -1.9512, 5.6602, -3.6387, -8.6953],
[-4.6953, 0.2900, 2.7148, -0.0176, 7.6992]]],],
params: [1.5],
res: [125.2500]
},
],
}
Loading

0 comments on commit 14935c0

Please sign in to comment.